From 5d66b7b0833569680bd72b55ed3056a4726bd7fe Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Tue, 23 Sep 2025 17:10:12 -0400 Subject: [PATCH 01/20] feat(rig-951): generic HTTP client --- Cargo.lock | 1 + rig-core/Cargo.toml | 1 + rig-core/src/client/mod.rs | 49 ++-- rig-core/src/client/verify.rs | 7 +- rig-core/src/completion/chat.rs | 203 ++++++++++++++++ rig-core/src/completion/request.rs | 4 +- rig-core/src/embeddings/embedding.rs | 2 +- rig-core/src/http_client.rs | 147 ++++++++++++ rig-core/src/lib.rs | 1 + rig-core/src/providers/anthropic/client.rs | 216 ++++++++++++------ .../src/providers/anthropic/completion.rs | 64 ++++-- .../src/providers/anthropic/decoders/sse.rs | 23 +- rig-core/src/providers/anthropic/streaming.rs | 49 ++-- rig-core/src/providers/cohere/client.rs | 133 ++++++----- rig-core/src/providers/cohere/completion.rs | 40 +++- rig-core/src/providers/cohere/embeddings.rs | 42 ++-- rig-core/src/providers/cohere/streaming.rs | 8 +- 17 files changed, 772 insertions(+), 218 deletions(-) create mode 100644 rig-core/src/completion/chat.rs create mode 100644 rig-core/src/http_client.rs diff --git a/Cargo.lock b/Cargo.lock index 894da6ccf..037ac1f93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9033,6 +9033,7 @@ dependencies = [ "epub", "futures", "glob", + "http 1.3.1", "hyper-util", "lopdf", "mime_guess", diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 6adcd0e7c..97ef558e0 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -43,6 +43,7 @@ worker = { workspace = true, optional = true } rmcp = { version = "0.6", optional = true, features = ["client"] } reqwest-eventsource = { workspace = true } tokio = { workspace = true, features = ["sync"] } +http = "1.3.1" [dev-dependencies] anyhow = { workspace = true } diff --git a/rig-core/src/client/mod.rs b/rig-core/src/client/mod.rs index a3a45fb5c..22a4f19ef 100644 --- a/rig-core/src/client/mod.rs +++ b/rig-core/src/client/mod.rs @@ -169,33 +169,38 @@ impl AsAudioGeneration for T {} 
#[cfg(not(feature = "image"))] impl AsImageGeneration for T {} -/// Implements the conversion traits for a given struct -/// ```rust -/// pub struct Client; -/// impl ProviderClient for Client { -/// ... -/// } -/// impl_conversion_traits!(AsCompletion, AsEmbeddings for Client); -/// ``` #[macro_export] macro_rules! impl_conversion_traits { - ($( $trait_:ident ),* for $struct_:ident ) => { - $( - impl_conversion_traits!(@impl $trait_ for $struct_); - )* + ($( $trait_:ident ),* for $($type_spec:tt)+) => { + impl_conversion_traits!(@expand_traits [$($trait_)+] $($type_spec)+); }; - (@impl AsAudioGeneration for $struct_:ident ) => { - rig::client::impl_audio_generation!($struct_); + (@expand_traits [$trait_:ident $($rest_traits:ident)*] $($type_spec:tt)+) => { + impl_conversion_traits!(@impl $trait_ for $($type_spec)+); + impl_conversion_traits!(@expand_traits [$($rest_traits)*] $($type_spec)+); }; - (@impl AsImageGeneration for $struct_:ident ) => { - rig::client::impl_image_generation!($struct_); + (@expand_traits [] $($type_spec:tt)+) => {}; + + (@impl AsAudioGeneration for $($type_spec:tt)+) => { + rig::client::impl_audio_generation!($($type_spec)+); + }; + + (@impl AsImageGeneration for $($type_spec:tt)+) => { + rig::client::impl_image_generation!($($type_spec)+); + }; + + (@impl $trait_:ident for $($type_spec:tt)+) => { + impl_conversion_traits!(@impl_trait $trait_ for $($type_spec)+); }; - (@impl $trait_:ident for $struct_:ident) => { + (@impl_trait $trait_:ident for $struct_:ident) => { impl rig::client::$trait_ for $struct_ {} }; + + (@impl_trait $trait_:ident for $struct_:ident<$($generics:tt),*>) => { + impl<$($generics),*> rig::client::$trait_ for $struct_<$($generics),*> {} + }; } #[cfg(feature = "audio")] @@ -204,12 +209,15 @@ macro_rules! 
impl_audio_generation { ($struct_:ident) => { impl rig::client::AsAudioGeneration for $struct_ {} }; + ($struct_:ident<$($generics:tt),*>) => { + impl<$($generics),*> rig::client::AsAudioGeneration for $struct_<$($generics),*> {} + }; } #[cfg(not(feature = "audio"))] #[macro_export] macro_rules! impl_audio_generation { - ($struct_:ident) => {}; + ($($tokens:tt)*) => {}; } #[cfg(feature = "image")] @@ -218,12 +226,15 @@ macro_rules! impl_image_generation { ($struct_:ident) => { impl rig::client::AsImageGeneration for $struct_ {} }; + ($struct_:ident<$($generics:tt),*>) => { + impl<$($generics),*> rig::client::AsImageGeneration for $struct_<$($generics),*> {} + }; } #[cfg(not(feature = "image"))] #[macro_export] macro_rules! impl_image_generation { - ($struct_:ident) => {}; + ($($tokens:tt)*) => {}; } pub use impl_audio_generation; diff --git a/rig-core/src/client/verify.rs b/rig-core/src/client/verify.rs index 0867d5e95..3e64a2880 100644 --- a/rig-core/src/client/verify.rs +++ b/rig-core/src/client/verify.rs @@ -1,4 +1,7 @@ -use crate::client::{AsVerify, ProviderClient}; +use crate::{ + client::{AsVerify, ProviderClient}, + http_client, +}; use futures::future::BoxFuture; use thiserror::Error; @@ -12,7 +15,7 @@ pub enum VerifyError { HttpError( #[from] #[source] - reqwest::Error, + http_client::HttpClientError, ), } diff --git a/rig-core/src/completion/chat.rs b/rig-core/src/completion/chat.rs new file mode 100644 index 000000000..e41226560 --- /dev/null +++ b/rig-core/src/completion/chat.rs @@ -0,0 +1,203 @@ +use crate::{ + agent::{Agent, MultiTurnStreamItem, Text}, + completion::{Chat, CompletionError, CompletionModel, PromptError, Usage}, + message::Message, + streaming::{StreamedAssistantContent, StreamingPrompt}, +}; +use futures::StreamExt; +use std::io::{self, Write}; + +pub struct NoImplProvided; + +pub struct ChatImpl(T); + +pub struct AgentImpl { + agent: Agent, + multi_turn_depth: usize, + show_usage: bool, + usage: Usage, +} + +pub struct 
ChatBotBuilder(T); + +pub struct ChatBot(T); + +/// Trait to abstract message behavior away from cli_chat/`run` loop +#[allow(private_interfaces)] +trait CliChat { + async fn request(&mut self, prompt: &str, history: Vec) + -> Result; + + fn show_usage(&self) -> bool { + false + } + + fn usage(&self) -> Option { + None + } +} + +impl CliChat for ChatImpl { + async fn request( + &mut self, + prompt: &str, + history: Vec, + ) -> Result { + let res = self.0.chat(prompt, history).await?; + println!("{res}"); + + Ok(res) + } +} + +impl CliChat for AgentImpl { + async fn request( + &mut self, + prompt: &str, + history: Vec, + ) -> Result { + let mut response_stream = self + .agent + .stream_prompt(prompt) + .with_history(history) + .multi_turn(self.multi_turn_depth) + .await; + + let mut acc = String::new(); + + loop { + let Some(chunk) = response_stream.next().await else { + break Ok(acc); + }; + + match chunk { + Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Text(Text { + text, + }))) => { + print!("{}", text); + acc.push_str(&text); + } + Ok(MultiTurnStreamItem::FinalResponse(final_response)) => { + self.usage = final_response.usage(); + } + Err(e) => { + break Err(PromptError::CompletionError( + CompletionError::ResponseError(e.to_string()), + )); + } + _ => continue, + } + } + } + + fn show_usage(&self) -> bool { + self.show_usage + } + + fn usage(&self) -> Option { + Some(self.usage) + } +} + +impl Default for ChatBotBuilder { + fn default() -> Self { + Self(NoImplProvided) + } +} + +impl ChatBotBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn agent( + self, + agent: Agent, + ) -> ChatBotBuilder> { + ChatBotBuilder(AgentImpl { + agent, + multi_turn_depth: 1, + show_usage: false, + usage: Usage::default(), + }) + } + + pub fn chat(self, chatbot: T) -> ChatBotBuilder> { + ChatBotBuilder(ChatImpl(chatbot)) + } +} + +impl ChatBotBuilder> { + pub fn build(self) -> ChatBot> { + let ChatBotBuilder(chat_impl) = self; + ChatBot(chat_impl) + 
} +} + +impl ChatBotBuilder> { + pub fn multi_turn_depth(self, multi_turn_depth: usize) -> Self { + ChatBotBuilder(AgentImpl { + multi_turn_depth, + ..self.0 + }) + } + + pub fn show_usage(self) -> Self { + ChatBotBuilder(AgentImpl { + show_usage: true, + ..self.0 + }) + } + + pub fn build(self) -> ChatBot> { + ChatBot(self.0) + } +} + +#[allow(private_bounds)] +impl ChatBot { + pub async fn run(mut self) -> Result<(), PromptError> { + let stdin = io::stdin(); + let mut stdout = io::stdout(); + let mut history = vec![]; + + loop { + print!("> "); + stdout.flush().unwrap(); + + let mut input = String::new(); + match stdin.read_line(&mut input) { + Ok(_) => { + let input = input.trim(); + if input == "exit" { + break; + } + + tracing::info!("Prompt:\n{input}\n"); + + println!(); + println!("========================== Response ============================"); + + let response = self.0.request(input, history.clone()).await?; + history.push(Message::user(input)); + history.push(Message::assistant(response)); + + println!("================================================================"); + println!(); + + if self.0.show_usage() { + let Usage { + input_tokens, + output_tokens, + .. 
+ } = self.0.usage().unwrap(); + println!("Input {input_tokens} tokens\nOutput {output_tokens} tokens"); + } + } + Err(e) => println!("Error reading request: {e}"), + } + } + + Ok(()) + } +} diff --git a/rig-core/src/completion/request.rs b/rig-core/src/completion/request.rs index 9c68f42a4..f1f9cc26b 100644 --- a/rig-core/src/completion/request.rs +++ b/rig-core/src/completion/request.rs @@ -66,7 +66,7 @@ use super::message::{AssistantContent, DocumentMediaType}; use crate::client::completion::CompletionModelHandle; use crate::streaming::StreamingCompletionResponse; -use crate::{OneOrMany, streaming}; +use crate::{OneOrMany, http_client, streaming}; use crate::{ json_utils, message::{Message, UserContent}, @@ -85,7 +85,7 @@ use thiserror::Error; pub enum CompletionError { /// Http error (e.g.: connection error, timeout, etc.) #[error("HttpError: {0}")] - HttpError(#[from] reqwest::Error), + HttpError(#[from] http_client::HttpClientError), /// Json error (e.g.: serialization, deserialization) #[error("JsonError: {0}")] diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs index b18d4a9a5..3b5518380 100644 --- a/rig-core/src/embeddings/embedding.rs +++ b/rig-core/src/embeddings/embedding.rs @@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize}; pub enum EmbeddingError { /// Http error (e.g.: connection error, timeout, etc.) 
#[error("HttpError: {0}")] - HttpError(#[from] reqwest::Error), + HttpError(Box), /// Json error (e.g.: serialization, deserialization) #[error("JsonError: {0}")] diff --git a/rig-core/src/http_client.rs b/rig-core/src/http_client.rs new file mode 100644 index 000000000..88f75bccc --- /dev/null +++ b/rig-core/src/http_client.rs @@ -0,0 +1,147 @@ +use bytes::Bytes; +use futures::stream::{BoxStream, StreamExt}; +pub use http::{HeaderValue, Method, Request, Response, Uri, request::Builder}; +use std::future::Future; +use std::pin::Pin; + +#[derive(Debug, thiserror::Error)] +pub enum HttpClientError { + #[error("Http error: {0}")] + Protocol(#[from] http::Error), + #[error("Http client error: {0}")] + Instance(#[from] Box), +} + +fn instance_error(error: E) -> HttpClientError { + HttpClientError::Instance(error.into()) +} + +pub type LazyBytes = Pin> + Send + 'static>>; +pub type LazyBody = Pin> + Send + 'static>>; + +pub type ByteStream = BoxStream<'static, Result>; +pub type StreamingResponse = Response; + +pub struct NoBody; + +impl From for Bytes { + fn from(_: NoBody) -> Self { + Bytes::new() + } +} + +pub trait HttpClientExt: Send + Sync { + fn request( + &self, + req: Request, + ) -> impl Future>, HttpClientError>> + Send + where + T: Into, + U: From + Send; + + fn request_streaming( + &self, + req: Request, + ) -> impl Future> + Send + where + T: Into; + + async fn get(&self, uri: Uri) -> Result>, HttpClientError> + where + T: From + Send, + { + let req = Request::builder() + .method(Method::GET) + .uri(uri) + .body(NoBody)?; + + self.request(req).await + } + + async fn post( + &self, + uri: Uri, + body: T, + ) -> Result>, HttpClientError> + where + U: TryInto, + >::Error: Into, + T: Into, + V: From + Send, + { + let req = Request::builder() + .method(Method::POST) + .uri(uri) + .body(body.into())?; + + self.request(req).await + } +} + +impl HttpClientExt for reqwest::Client { + fn request( + &self, + req: Request, + ) -> impl Future>, HttpClientError>> + Send 
+ where + T: Into, + U: From + Send, + { + let (parts, body) = req.into_parts(); + let req = self + .request(parts.method, parts.uri.to_string()) + .headers(parts.headers) + .body(body.into()); + + async move { + let response = req.send().await.map_err(instance_error)?; + + let mut res = Response::builder() + .status(response.status()) + .version(response.version()); + + if let Some(hs) = res.headers_mut() { + *hs = response.headers().clone(); + } + + let body: LazyBody = Box::pin(async move { + let bytes = response.bytes().await.map_err(instance_error)?; + let body = U::from(bytes); + Ok(body) + }); + + res.body(body).map_err(HttpClientError::Protocol) + } + } + + fn request_streaming( + &self, + req: Request, + ) -> impl Future> + Send + where + T: Into, + { + let (parts, body) = req.into_parts(); + let req = self + .request(parts.method, parts.uri.to_string()) + .headers(parts.headers) + .body(body.into()); + + async move { + let response: reqwest::Response = req.send().await.map_err(instance_error)?; + + let mut res = Response::builder() + .status(response.status()) + .version(response.version()); + + if let Some(hs) = res.headers_mut() { + *hs = response.headers().clone(); + } + + let stream: ByteStream = + Box::pin(response.bytes_stream().map(|r| r.map_err(instance_error))); + + Ok(res.body(stream)?) + } + } +} diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 4f23304e6..c05eeb45b 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -115,6 +115,7 @@ pub mod client; pub mod completion; pub mod embeddings; pub mod extractor; +pub mod http_client; #[cfg(feature = "image")] #[cfg_attr(docsrs, doc(cfg(feature = "image")))] pub mod image_generation; diff --git a/rig-core/src/providers/anthropic/client.rs b/rig-core/src/providers/anthropic/client.rs index fc072b228..7c9898c47 100644 --- a/rig-core/src/providers/anthropic/client.rs +++ b/rig-core/src/providers/anthropic/client.rs @@ -1,8 +1,14 @@ //! 
Anthropic client api implementation +use bytes::Bytes; +use http_client::{Method, Request, Uri}; + use super::completion::{ANTHROPIC_VERSION_LATEST, CompletionModel}; -use crate::client::{ - ClientBuilderError, CompletionClient, ProviderClient, ProviderValue, VerifyClient, VerifyError, - impl_conversion_traits, +use crate::{ + client::{ + ClientBuilderError, CompletionClient, ProviderClient, ProviderValue, VerifyClient, + VerifyError, impl_conversion_traits, + }, + http_client::{self, HttpClientError, HttpClientExt}, }; // ================================================================ @@ -10,12 +16,12 @@ use crate::client::{ // ================================================================ const ANTHROPIC_API_BASE_URL: &str = "https://api.anthropic.com"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { api_key: &'a str, base_url: &'a str, anthropic_version: &'a str, anthropic_betas: Option>, - http_client: Option, + http_client: T, } /// Create a new anthropic client using the builder @@ -30,14 +36,27 @@ pub struct ClientBuilder<'a> { /// .anthropic_beta("prompt-caching-2024-07-31") /// .build() /// ``` -impl<'a> ClientBuilder<'a> { +impl<'a, T> ClientBuilder<'a, T> +where + T: HttpClientExt + Default, +{ pub fn new(api_key: &'a str) -> Self { + ClientBuilder { + api_key, + base_url: ANTHROPIC_API_BASE_URL, + anthropic_version: ANTHROPIC_VERSION_LATEST, + anthropic_betas: None, + http_client: Default::default(), + } + } + + pub fn with_client(api_key: &'a str, http_client: T) -> Self { Self { api_key, base_url: ANTHROPIC_API_BASE_URL, anthropic_version: ANTHROPIC_VERSION_LATEST, anthropic_betas: None, - http_client: None, + http_client, } } @@ -61,12 +80,7 @@ impl<'a> ClientBuilder<'a> { self } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); - self - } - - pub fn build(self) -> Result { + pub fn build(self) -> Result, ClientBuilderError> { let mut default_headers = 
reqwest::header::HeaderMap::new(); default_headers.insert( "anthropic-version", @@ -74,6 +88,7 @@ impl<'a> ClientBuilder<'a> { .parse() .map_err(|_| ClientBuilderError::InvalidProperty("anthropic-version"))?, ); + if let Some(betas) = self.anthropic_betas { default_headers.insert( "anthropic-beta", @@ -84,34 +99,31 @@ impl<'a> ClientBuilder<'a> { ); }; - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? - }; - Ok(Client { base_url: self.base_url.to_string(), api_key: self.api_key.to_string(), default_headers, - http_client, + http_client: self.http_client, }) } } #[derive(Clone)] -pub struct Client { +pub struct Client { /// The base URL base_url: String, /// The API key api_key: String, /// The underlying HTTP client - http_client: reqwest::Client, + http_client: T, /// Default headers that will be automatically added to any given request with this client (API key, Anthropic Version and any betas that have been added) default_headers: reqwest::header::HeaderMap, } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client +where + T: HttpClientExt + std::fmt::Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -122,56 +134,102 @@ impl std::fmt::Debug for Client { } } -impl Client { - /// Create a new Anthropic client builder. 
- /// - /// # Example - /// ``` - /// use rig::providers::anthropic::{ClientBuilder, self}; - /// - /// // Initialize the Anthropic client - /// let anthropic_client = Client::builder("your-claude-api-key") - /// .anthropic_version(ANTHROPIC_VERSION_LATEST) - /// .anthropic_beta("prompt-caching-2024-07-31") - /// .build() - /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_> { - ClientBuilder::new(api_key) - } +fn build_uri(path: &str) -> Result { + Uri::builder() + .scheme("https") + .authority("api.anthropic.com") + .path_and_query(path) + .build() +} +impl Client +where + T: HttpClientExt + Clone + Default, +{ /// Create a new Anthropic client. For more control, use the `builder` method. /// /// # Panics /// - If the API key or version cannot be parsed as a Json value from a String. /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). pub fn new(api_key: &str) -> Self { - Self::builder(api_key) + ClientBuilder::new(api_key) .build() .expect("Anthropic client should build") } - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client - .post(url) - .header("X-Api-Key", &self.api_key) - .headers(self.default_headers.clone()) + pub async fn send( + &self, + req: http_client::Request, + ) -> Result>, http_client::HttpClientError> + where + U: Into, + V: From + Send, + { + self.http_client.request(req).await } - pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client - .get(url) - .header("X-Api-Key", &self.api_key) - .headers(self.default_headers.clone()) + pub async fn send_streaming( + &self, + req: Request, + ) -> Result + where + U: Into, + { + self.http_client.request_streaming(req).await + } + + pub(crate) fn post(&self, path: &str) -> http_client::Builder { + let uri = format!("{}/{}", ANTHROPIC_API_BASE_URL, 
path).replace("//", "/"); + + let mut headers = self.default_headers.clone(); + + headers.insert( + "X-Api-Key", + http_client::HeaderValue::from_str(&self.api_key).unwrap(), + ); + + let mut req = http_client::Request::builder() + .method(Method::POST) + .uri(uri); + + if let Some(hs) = req.headers_mut() { + *hs = headers; + } + + req + } + + pub(crate) fn get( + &self, + path: &str, + ) -> Result, http::Error> { + let uri = format!("{}/{}", self.base_url, path).replace("//", "/"); + + let mut headers = self.default_headers.clone(); + headers.insert( + "X-Api-Key", + http_client::HeaderValue::from_str(&self.api_key).unwrap(), + ); + + let mut req = http_client::Request::builder().method(Method::GET).uri(uri); + + if let Some(hs) = req.headers_mut() { + *hs = headers; + } + + req.body(http_client::NoBody) } } -impl ProviderClient for Client { +impl ProviderClient for Client +where + T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static, +{ /// Create a new Anthropic client from the `ANTHROPIC_API_KEY` environment variable. /// Panics if the environment variable is not set. 
fn from_env() -> Self { let api_key = std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"); + Client::new(&api_key) } @@ -179,35 +237,60 @@ impl ProviderClient for Client { let ProviderValue::Simple(api_key) = input else { panic!("Incorrect provider value type") }; + Client::new(&api_key) } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; - fn completion_model(&self, model: &str) -> CompletionModel { +impl CompletionClient for Client +where + T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static, +{ + type CompletionModel = CompletionModel; + + fn completion_model(&self, model: &str) -> CompletionModel { CompletionModel::new(self.clone(), model) } } -impl VerifyClient for Client { +impl VerifyClient for Client +where + T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static, +{ #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let response = self.get("/v1/models").send().await?; + let response: http_client::Response>> = self + .http_client + .request( + self.get("/v1/models") + .map_err(|e| http_client::HttpClientError::Protocol(e))?, + ) + .await?; + match response.status() { - reqwest::StatusCode::OK => Ok(()), - reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => { + http::StatusCode::OK => Ok(()), + http::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => { Err(VerifyError::InvalidAuthentication) } - reqwest::StatusCode::INTERNAL_SERVER_ERROR => { - Err(VerifyError::ProviderError(response.text().await?)) + http::StatusCode::INTERNAL_SERVER_ERROR => { + let text = String::from_utf8_lossy(&response.into_body().await?).into(); + Err(VerifyError::ProviderError(text)) } status if status.as_u16() == 529 => { - Err(VerifyError::ProviderError(response.text().await?)) + let text = String::from_utf8_lossy(&response.into_body().await?).into(); + Err(VerifyError::ProviderError(text)) } _ => { - response.error_for_status()?; - Ok(()) 
+ let status = response.status(); + + if status.is_success() { + Ok(()) + } else { + let text: String = String::from_utf8_lossy(&response.into_body().await?).into(); + Err(VerifyError::HttpError(HttpClientError::Instance( + format!("Failed with '{status}': {text}").into(), + ))) + } } } } @@ -217,5 +300,6 @@ impl_conversion_traits!( AsTranscription, AsEmbeddings, AsImageGeneration, - AsAudioGeneration for Client + AsAudioGeneration + for Client ); diff --git a/rig-core/src/providers/anthropic/completion.rs b/rig-core/src/providers/anthropic/completion.rs index 650db5f61..2da946021 100644 --- a/rig-core/src/providers/anthropic/completion.rs +++ b/rig-core/src/providers/anthropic/completion.rs @@ -3,6 +3,7 @@ use crate::{ OneOrMany, completion::{self, CompletionError}, + http_client::HttpClientExt, json_utils, message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning}, one_or_many::string_or_one_or_many, @@ -12,6 +13,7 @@ use std::{convert::Infallible, str::FromStr}; use super::client::Client; use crate::completion::CompletionRequest; use crate::providers::anthropic::streaming::StreamingCompletionResponse; +use bytes::Bytes; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -576,14 +578,17 @@ impl TryFrom for message::Message { } #[derive(Clone)] -pub struct CompletionModel { - pub(crate) client: Client, +pub struct CompletionModel { + pub(crate) client: Client, pub model: String, pub default_max_tokens: Option, } -impl CompletionModel { - pub fn new(client: Client, model: &str) -> Self { +impl CompletionModel +where + T: HttpClientExt, +{ + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -630,7 +635,10 @@ pub enum ToolChoice { }, } -impl completion::CompletionModel for CompletionModel { +impl completion::CompletionModel for CompletionModel +where + T: HttpClientExt + Clone + Default, +{ type Response = CompletionResponse; type StreamingResponse = StreamingCompletionResponse; @@ -700,26 
+708,52 @@ impl completion::CompletionModel for CompletionModel { tracing::debug!("Anthropic completion request: {request}"); - let response = self + let request: Vec = serde_json::to_vec(&request)?; + + let req = self .client .post("/v1/messages") - .json(&request) - .send() - .await?; + .body(request) + .map_err(|e| CompletionError::HttpError(e.into()))?; + + let response = self + .client + .send::<_, Bytes>(req) + .await + .map_err(|e| CompletionError::HttpError(e.into()))?; if response.status().is_success() { - match response.json::>().await? { + match serde_json::from_slice::>( + response + .into_body() + .await + .map_err(|e| CompletionError::HttpError(e.into()))? + .to_vec() + .as_slice(), + )? { ApiResponse::Message(completion) => { - tracing::info!(target: "rig", - "Anthropic completion token usage: {}", - completion.usage + let completion: Result, _> = + completion.try_into(); + + tracing::info!( + target: "rig", + "Anthropic completion token usage: {:?}", + completion ); - completion.try_into() + + completion } ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)), } } else { - Err(CompletionError::ProviderError(response.text().await?)) + let text: String = String::from_utf8_lossy( + &response + .into_body() + .await + .map_err(|e| CompletionError::HttpError(e.into()))?, + ) + .into(); + Err(CompletionError::ProviderError(text)) } } diff --git a/rig-core/src/providers/anthropic/decoders/sse.rs b/rig-core/src/providers/anthropic/decoders/sse.rs index d73f1b003..fdbbe5c02 100644 --- a/rig-core/src/providers/anthropic/decoders/sse.rs +++ b/rig-core/src/providers/anthropic/decoders/sse.rs @@ -1,5 +1,6 @@ use super::line::{self, LineDecoder}; -use futures::{Stream, StreamExt}; +use bytes::Bytes; +use futures::{Stream, StreamExt, stream::BoxStream}; use std::fmt::Debug; use thiserror::Error; @@ -181,14 +182,14 @@ fn extract_sse_chunk(buffer: &[u8]) -> Option<(Vec, Vec)> { Some((chunk, remaining)) } -pub fn from_response( - response: 
reqwest::Response, -) -> impl Stream> { - let stream = response.bytes_stream().map(|result| { - result - .map_err(std::io::Error::other) - .map(|bytes| bytes.to_vec()) - }); - - iter_sse_messages(stream) +pub fn from_response<'a, E>( + stream: BoxStream<'a, Result>, +) -> impl Stream> +where + E: Into>, +{ + iter_sse_messages(stream.map(|result| match result { + Ok(bytes) => Ok(bytes.to_vec()), + Err(e) => Err(std::io::Error::other(e)), + })) } diff --git a/rig-core/src/providers/anthropic/streaming.rs b/rig-core/src/providers/anthropic/streaming.rs index 8dca93d8f..fa16272b8 100644 --- a/rig-core/src/providers/anthropic/streaming.rs +++ b/rig-core/src/providers/anthropic/streaming.rs @@ -6,6 +6,7 @@ use serde_json::json; use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage}; use super::decoders::sse::from_response as sse_from_response; use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage}; +use crate::http_client::{self, HttpClientExt}; use crate::json_utils::merge_inplace; use crate::streaming; use crate::streaming::{RawStreamingChoice, StreamingResult}; @@ -92,7 +93,10 @@ impl GetTokenUsage for StreamingCompletionResponse { } } -impl CompletionModel { +impl CompletionModel +where + T: HttpClientExt + Clone + Default, +{ pub(crate) async fn stream( &self, completion_request: CompletionRequest, @@ -119,7 +123,7 @@ impl CompletionModel { .map(Message::try_from) .collect::, _>>()?; - let mut request = json!({ + let mut body = json!({ "model": self.model, "messages": full_history, "max_tokens": max_tokens, @@ -128,12 +132,12 @@ impl CompletionModel { }); if let Some(temperature) = completion_request.temperature { - merge_inplace(&mut request, json!({ "temperature": temperature })); + merge_inplace(&mut body, json!({ "temperature": temperature })); } if !completion_request.tools.is_empty() { merge_inplace( - &mut request, + &mut body, json!({ "tools": completion_request .tools @@ -150,26 +154,42 @@ impl 
CompletionModel { } if let Some(ref params) = completion_request.additional_params { - merge_inplace(&mut request, params.clone()) + merge_inplace(&mut body, params.clone()) } - let response = self + let body: Vec = serde_json::to_vec(&body)?; + + let req = self .client .post("/v1/messages") - .json(&request) - .send() - .await?; + .body(body) + .map_err(|e| http_client::HttpClientError::Protocol(e))?; + + let response: http_client::StreamingResponse = self.client.send_streaming(req).await?; if !response.status().is_success() { - return Err(CompletionError::ProviderError(response.text().await?)); + let mut stream = response.into_body(); + let mut text = String::with_capacity(1024); + loop { + let Some(chunk) = stream.next().await else { + break; + }; + + let chunk = chunk?; + + let str = String::from_utf8_lossy(&chunk); + + text.push_str(&str) + } + return Err(CompletionError::ProviderError(text)); } - // Use our SSE decoder to directly handle Server-Sent Events format - let sse_stream = sse_from_response(response); + let stream = sse_from_response(response.into_body()); + // Use our SSE decoder to directly handle Server-Sent Events format let stream: StreamingResult = Box::pin(stream! { let mut current_tool_call: Option = None; - let mut sse_stream = Box::pin(sse_stream); + let mut sse_stream = Box::pin(stream); let mut input_tokens = 0; while let Some(sse_result) = sse_stream.next().await { @@ -184,7 +204,6 @@ impl CompletionModel { }, StreamingEvent::MessageDelta { delta, usage } => { if delta.stop_reason.is_some() { - yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { usage: PartialUsage { output_tokens: usage.output_tokens, @@ -253,7 +272,7 @@ fn handle_event( _ => None, }, StreamingEvent::ContentBlockStop { .. 
} => { - if let Some(tool_call) = current_tool_call.take() { + if let Some(tool_call) = Option::take(current_tool_call) { let json_str = if tool_call.input_json.is_empty() { "{}" } else { diff --git a/rig-core/src/providers/cohere/client.rs b/rig-core/src/providers/cohere/client.rs index 97b0dc918..a8dff1091 100644 --- a/rig-core/src/providers/cohere/client.rs +++ b/rig-core/src/providers/cohere/client.rs @@ -2,13 +2,13 @@ use crate::{ Embed, client::{VerifyClient, VerifyError}, embeddings::EmbeddingsBuilder, + http_client::{self, HttpClientExt}, }; use super::{CompletionModel, EmbeddingModel}; -use crate::client::{ - ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, impl_conversion_traits, -}; +use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, impl_conversion_traits}; use serde::Deserialize; +use url::ParseError; #[derive(Debug, Deserialize)] pub struct ApiErrorResponse { @@ -27,54 +27,60 @@ pub enum ApiResponse { // ================================================================ const COHERE_API_BASE_URL: &str = "https://api.cohere.ai"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> +where + T: HttpClientExt, +{ api_key: &'a str, base_url: &'a str, - http_client: Option, + http_client: T, } -impl<'a> ClientBuilder<'a> { - pub fn new(api_key: &'a str) -> Self { - Self { +impl<'a, T> ClientBuilder<'a, T> +where + T: HttpClientExt, +{ + pub fn new(api_key: &'a str) -> ClientBuilder<'a, reqwest::Client> { + ClientBuilder { api_key, base_url: COHERE_API_BASE_URL, - http_client: None, + http_client: reqwest::Client::new(), } } - pub fn base_url(mut self, base_url: &'a str) -> Self { - self.base_url = base_url; - self + pub fn with_client(api_key: &'a str, http_client: T) -> Self { + ClientBuilder { + api_key, + base_url: COHERE_API_BASE_URL, + http_client, + } } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); + pub fn base_url(mut self, base_url: &'a 
str) -> Self { + self.base_url = base_url; self } - pub fn build(self) -> Result { - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? - }; - - Ok(Client { + pub fn build(self) -> Client { + Client { base_url: self.base_url.to_string(), api_key: self.api_key.to_string(), - http_client, - }) + http_client: self.http_client, + } } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, - http_client: reqwest::Client, + http_client: T, } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client +where + T: HttpClientExt + std::fmt::Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -84,52 +90,49 @@ impl std::fmt::Debug for Client { } } -impl Client { - /// Create a new Cohere client builder. - /// - /// # Example - /// ``` - /// use rig::providers::cohere::{ClientBuilder, self}; - /// - /// // Initialize the Cohere client - /// let cohere_client = Client::builder("your-cohere-api-key") - /// .build() - /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_> { - ClientBuilder::new(api_key) - } - +impl Client +where + T: HttpClientExt + Clone, +{ /// Create a new Cohere client. For more control, use the `builder` method. /// /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 
- pub fn new(api_key: &str) -> Self { - Self::builder(api_key) - .build() - .expect("Cohere client should build") + pub fn new(api_key: &str) -> Client { + ClientBuilder::with_client(api_key, reqwest::Client::new()).build() } - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { + pub(crate) fn post(&self, path: &str) -> Result, ParseError> + where + U: From + Send, + { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.post(url).bearer_auth(&self.api_key) + Ok(http_client::Request::post(url.as_str())?.bearer_auth(self.api_key.as_str())) } - pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder { + pub(crate) fn get(&self, path: &str) -> Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.get(url).bearer_auth(&self.api_key) + Ok(http_client::Request::get(url.as_str())?.bearer_auth(self.api_key.as_str())) + } + + pub(crate) async fn send( + &self, + req: http_client::Request, + ) -> Result::Error> { + self.http_client.request(req).await } pub fn embeddings( &self, model: &str, input_type: &str, - ) -> EmbeddingsBuilder { + ) -> EmbeddingsBuilder, D> { EmbeddingsBuilder::new(self.embedding_model(model, input_type)) } /// Note: default embedding dimension of 0 will be used if model is not known. 
/// If this is the case, it's better to use function `embedding_model_with_ndims` - pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel { + pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel { let ndims = match model { super::EMBED_ENGLISH_V3 | super::EMBED_MULTILINGUAL_V3 @@ -148,12 +151,19 @@ impl Client { model: &str, input_type: &str, ndims: usize, - ) -> EmbeddingModel { + ) -> EmbeddingModel { EmbeddingModel::new(self.clone(), model, input_type, ndims) } } -impl ProviderClient for Client { +impl Client { + pub(crate) async fn eventsource(&self, req: http_client::Request) -> _ { + let req: reqwest::Request = req.into(); + todo!() + } +} + +impl ProviderClient for Client { /// Create a new Cohere client from the `COHERE_API_KEY` environment variable. /// Panics if the environment variable is not set. fn from_env() -> Self { @@ -169,16 +179,16 @@ impl ProviderClient for Client { } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; fn completion_model(&self, model: &str) -> Self::CompletionModel { CompletionModel::new(self.clone(), model) } } -impl EmbeddingsClient for Client { - type EmbeddingModel = EmbeddingModel; +impl EmbeddingsClient for Client { + type EmbeddingModel = EmbeddingModel; fn embedding_model(&self, model: &str) -> Self::EmbeddingModel { self.embedding_model(model, "search_document") @@ -193,10 +203,11 @@ impl EmbeddingsClient for Client { } } -impl VerifyClient for Client { +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let response = self.get("/v1/models").send().await?; + let response = self.http_client.get("/v1/models").send().await?; + match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication), @@ -214,5 +225,5 @@ impl 
VerifyClient for Client { impl_conversion_traits!( AsTranscription, AsImageGeneration, - AsAudioGeneration for Client + AsAudioGeneration for Client ); diff --git a/rig-core/src/providers/cohere/completion.rs b/rig-core/src/providers/cohere/completion.rs index 1170ee4f6..062130fb7 100644 --- a/rig-core/src/providers/cohere/completion.rs +++ b/rig-core/src/providers/cohere/completion.rs @@ -1,6 +1,7 @@ use crate::{ OneOrMany, completion::{self, CompletionError}, + http_client::HttpClientExt, json_utils, message::{self, Reasoning}, }; @@ -455,13 +456,16 @@ impl TryFrom for message::Message { } #[derive(Clone)] -pub struct CompletionModel { - pub(crate) client: Client, +pub struct CompletionModel { + pub(crate) client: Client, pub model: String, } -impl CompletionModel { - pub fn new(client: Client, model: &str) -> Self { +impl CompletionModel +where + T: HttpClientExt, +{ + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -513,7 +517,10 @@ impl CompletionModel { } } -impl completion::CompletionModel for CompletionModel { +impl completion::CompletionModel for CompletionModel +where + T: HttpClientExt + Clone, +{ type Response = CompletionResponse; type StreamingResponse = StreamingCompletionResponse; @@ -528,18 +535,31 @@ impl completion::CompletionModel for CompletionModel { serde_json::to_string_pretty(&request)? ); - let response = self.client.post("/v2/chat").json(&request).send().await?; + let req = self + .client + .post("/v2/chat") + .map_err(|e| CompletionError::HttpError(e.into()))? 
+ .with_json(&request); + + let response = self + .client + .send(req) + .await + .map_err(|e| CompletionError::HttpError(e.into()))?; + + let text = response + .text() + .map_err(|e| CompletionError::ResponseError(e.to_string()))?; if response.status().is_success() { - let text_response = response.text().await?; - tracing::debug!("Cohere response text: {}", text_response); + tracing::debug!("Cohere response text: {}", text); - let json_response: CompletionResponse = serde_json::from_str(&text_response)?; + let json_response: CompletionResponse = serde_json::from_str(&text)?; let completion: completion::CompletionResponse = json_response.try_into()?; Ok(completion) } else { - Err(CompletionError::ProviderError(response.text().await?)) + Err(CompletionError::ProviderError(text.to_string())) } } diff --git a/rig-core/src/providers/cohere/embeddings.rs b/rig-core/src/providers/cohere/embeddings.rs index 9a925ae72..12ef5a4d7 100644 --- a/rig-core/src/providers/cohere/embeddings.rs +++ b/rig-core/src/providers/cohere/embeddings.rs @@ -1,6 +1,9 @@ use super::{Client, client::ApiResponse}; -use crate::embeddings::{self, EmbeddingError}; +use crate::{ + embeddings::{self, EmbeddingError}, + http_client::HttpClientExt, +}; use serde::Deserialize; use serde_json::json; @@ -56,14 +59,17 @@ impl std::fmt::Display for BilledUnits { } #[derive(Clone)] -pub struct EmbeddingModel { - client: Client, +pub struct EmbeddingModel { + client: Client, pub model: String, pub input_type: String, ndims: usize, } -impl embeddings::EmbeddingModel for EmbeddingModel { +impl embeddings::EmbeddingModel for EmbeddingModel +where + T: HttpClientExt + Clone, +{ const MAX_DOCUMENTS: usize = 96; fn ndims(&self) -> usize { @@ -77,19 +83,24 @@ impl embeddings::EmbeddingModel for EmbeddingModel { ) -> Result, EmbeddingError> { let documents = documents.into_iter().collect::>(); - let response = self + let req = self .client .post("/v1/embed") - .json(&json!({ + .map_err(|e| 
EmbeddingError::HttpError(e.into()))? + .with_json(&json!({ "model": self.model, "texts": documents, "input_type": self.input_type, - })) - .send() - .await?; + })); + + let response = self + .client + .send(req) + .await + .map_err(|e| EmbeddingError::HttpError(e.into()))?; if response.status().is_success() { - match response.json::>().await? { + match response.json::>()? { ApiResponse::Ok(response) => { match response.meta { Some(meta) => tracing::info!(target: "rig", @@ -125,13 +136,18 @@ impl embeddings::EmbeddingModel for EmbeddingModel { ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)), } } else { - Err(EmbeddingError::ProviderError(response.text().await?)) + Err(EmbeddingError::ProviderError( + response + .text() + .map_err(|e| EmbeddingError::HttpError(e.into()))? + .to_string(), + )) } } } -impl EmbeddingModel { - pub fn new(client: Client, model: &str, input_type: &str, ndims: usize) -> Self { +impl EmbeddingModel { + pub fn new(client: Client, model: &str, input_type: &str, ndims: usize) -> Self { Self { client, model: model.to_string(), diff --git a/rig-core/src/providers/cohere/streaming.rs b/rig-core/src/providers/cohere/streaming.rs index 81c3e8451..18e6cced0 100644 --- a/rig-core/src/providers/cohere/streaming.rs +++ b/rig-core/src/providers/cohere/streaming.rs @@ -1,4 +1,5 @@ use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage}; +use crate::http_client::HttpClientExt; use crate::providers::cohere::CompletionModel; use crate::providers::cohere::completion::Usage; use crate::streaming::RawStreamingChoice; @@ -84,7 +85,8 @@ impl GetTokenUsage for StreamingCompletionResponse { } } -impl CompletionModel { +impl CompletionModel +{ pub(crate) async fn stream( &self, request: CompletionRequest, @@ -98,10 +100,10 @@ impl CompletionModel { serde_json::to_string_pretty(&request)? 
); + let req = self.client.post("/v2/chat").map_err(|e| CompletionError::HttpError(e.into()))?.with_json(&request) + let mut event_source = self .client - .post("/v2/chat") - .json(&request) .eventsource() .map_err(|e| CompletionError::ProviderError(e.to_string()))?; From 5e09efbd4015d20900a74a1b3fbd29d9ef7b96e7 Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Mon, 29 Sep 2025 16:23:25 -0400 Subject: [PATCH 02/20] Cohere, gemini, huggingface clients --- rig-core/src/completion/request.rs | 2 +- rig-core/src/embeddings/embedding.rs | 4 +- rig-core/src/http_client.rs | 41 ++-- rig-core/src/image_generation.rs | 4 +- rig-core/src/providers/anthropic/client.rs | 10 +- rig-core/src/providers/anthropic/streaming.rs | 2 +- rig-core/src/providers/cohere/client.rs | 41 ++-- rig-core/src/providers/cohere/completion.rs | 33 ++-- rig-core/src/providers/cohere/embeddings.rs | 30 +-- rig-core/src/providers/cohere/streaming.rs | 11 +- rig-core/src/providers/gemini/client.rs | 182 ++++++++++++------ rig-core/src/providers/gemini/completion.rs | 57 ++++-- rig-core/src/providers/gemini/embedding.rs | 33 ++-- rig-core/src/providers/gemini/streaming.rs | 2 +- .../src/providers/gemini/transcription.rs | 37 ++-- rig-core/src/providers/huggingface/client.rs | 150 ++++++++++----- .../src/providers/huggingface/completion.rs | 33 +++- .../providers/huggingface/image_generation.rs | 31 ++- .../src/providers/huggingface/streaming.rs | 6 +- .../providers/huggingface/transcription.rs | 37 ++-- rig-core/src/transcription.rs | 4 +- 21 files changed, 494 insertions(+), 256 deletions(-) diff --git a/rig-core/src/completion/request.rs b/rig-core/src/completion/request.rs index f1f9cc26b..52d79ba1c 100644 --- a/rig-core/src/completion/request.rs +++ b/rig-core/src/completion/request.rs @@ -85,7 +85,7 @@ use thiserror::Error; pub enum CompletionError { /// Http error (e.g.: connection error, timeout, etc.) 
#[error("HttpError: {0}")] - HttpError(#[from] http_client::HttpClientError), + HttpError(#[from] http_client::Error), /// Json error (e.g.: serialization, deserialization) #[error("JsonError: {0}")] diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs index 3b5518380..df8cafc87 100644 --- a/rig-core/src/embeddings/embedding.rs +++ b/rig-core/src/embeddings/embedding.rs @@ -9,11 +9,13 @@ use futures::future::BoxFuture; use serde::{Deserialize, Serialize}; +use crate::http_client; + #[derive(Debug, thiserror::Error)] pub enum EmbeddingError { /// Http error (e.g.: connection error, timeout, etc.) #[error("HttpError: {0}")] - HttpError(Box), + HttpError(#[from] http_client::Error), /// Json error (e.g.: serialization, deserialization) #[error("JsonError: {0}")] diff --git a/rig-core/src/http_client.rs b/rig-core/src/http_client.rs index 88f75bccc..74a8de967 100644 --- a/rig-core/src/http_client.rs +++ b/rig-core/src/http_client.rs @@ -1,25 +1,28 @@ use bytes::Bytes; use futures::stream::{BoxStream, StreamExt}; pub use http::{HeaderValue, Method, Request, Response, Uri, request::Builder}; +use reqwest::Body; use std::future::Future; use std::pin::Pin; #[derive(Debug, thiserror::Error)] -pub enum HttpClientError { +pub enum Error { #[error("Http error: {0}")] Protocol(#[from] http::Error), #[error("Http client error: {0}")] Instance(#[from] Box), } -fn instance_error(error: E) -> HttpClientError { - HttpClientError::Instance(error.into()) +pub type Result = std::result::Result; + +fn instance_error(error: E) -> Error { + Error::Instance(error.into()) } -pub type LazyBytes = Pin> + Send + 'static>>; -pub type LazyBody = Pin> + Send + 'static>>; +pub type LazyBytes = Pin> + Send + 'static>>; +pub type LazyBody = Pin> + Send + 'static>>; -pub type ByteStream = BoxStream<'static, Result>; +pub type ByteStream = BoxStream<'static, Result>; pub type StreamingResponse = Response; pub struct NoBody; @@ -30,11 +33,17 @@ impl From for Bytes { 
} } +impl From for Body { + fn from(_: NoBody) -> Self { + reqwest::Body::default() + } +} + pub trait HttpClientExt: Send + Sync { fn request( &self, req: Request, - ) -> impl Future>, HttpClientError>> + Send + ) -> impl Future>>> + Send where T: Into, U: From + Send; @@ -42,11 +51,11 @@ pub trait HttpClientExt: Send + Sync { fn request_streaming( &self, req: Request, - ) -> impl Future> + Send + ) -> impl Future> + Send where T: Into; - async fn get(&self, uri: Uri) -> Result>, HttpClientError> + async fn get(&self, uri: Uri) -> Result>> where T: From + Send, { @@ -58,14 +67,10 @@ pub trait HttpClientExt: Send + Sync { self.request(req).await } - async fn post( - &self, - uri: Uri, - body: T, - ) -> Result>, HttpClientError> + async fn post(&self, uri: Uri, body: T) -> Result>> where U: TryInto, - >::Error: Into, + >::Error: Into, T: Into, V: From + Send, { @@ -82,7 +87,7 @@ impl HttpClientExt for reqwest::Client { fn request( &self, req: Request, - ) -> impl Future>, HttpClientError>> + Send + ) -> impl Future>>> + Send where T: Into, U: From + Send, @@ -110,14 +115,14 @@ impl HttpClientExt for reqwest::Client { Ok(body) }); - res.body(body).map_err(HttpClientError::Protocol) + res.body(body).map_err(Error::Protocol) } } fn request_streaming( &self, req: Request, - ) -> impl Future> + Send + ) -> impl Future> + Send where T: Into, { diff --git a/rig-core/src/image_generation.rs b/rig-core/src/image_generation.rs index e1c0a23a4..8773c0e99 100644 --- a/rig-core/src/image_generation.rs +++ b/rig-core/src/image_generation.rs @@ -1,6 +1,6 @@ //! Everything related to core image generation abstractions in Rig. //! Rig allows calling a number of different providers (that support image generation) using the [ImageGenerationModel] trait. 
-use crate::client::image_generation::ImageGenerationModelHandle; +use crate::{client::image_generation::ImageGenerationModelHandle, http_client}; use futures::future::BoxFuture; use serde_json::Value; use std::sync::Arc; @@ -10,7 +10,7 @@ use thiserror::Error; pub enum ImageGenerationError { /// Http error (e.g.: connection error, timeout, etc.) #[error("HttpError: {0}")] - HttpError(#[from] reqwest::Error), + HttpError(#[from] http_client::Error), /// Json error (e.g.: serialization, deserialization) #[error("JsonError: {0}")] diff --git a/rig-core/src/providers/anthropic/client.rs b/rig-core/src/providers/anthropic/client.rs index 7c9898c47..4774587d6 100644 --- a/rig-core/src/providers/anthropic/client.rs +++ b/rig-core/src/providers/anthropic/client.rs @@ -8,7 +8,7 @@ use crate::{ ClientBuilderError, CompletionClient, ProviderClient, ProviderValue, VerifyClient, VerifyError, impl_conversion_traits, }, - http_client::{self, HttpClientError, HttpClientExt}, + http_client::{self, HttpClientExt}, }; // ================================================================ @@ -160,7 +160,7 @@ where pub async fn send( &self, req: http_client::Request, - ) -> Result>, http_client::HttpClientError> + ) -> Result>, http_client::Error> where U: Into, V: From + Send, @@ -171,7 +171,7 @@ where pub async fn send_streaming( &self, req: Request, - ) -> Result + ) -> Result where U: Into, { @@ -263,7 +263,7 @@ where .http_client .request( self.get("/v1/models") - .map_err(|e| http_client::HttpClientError::Protocol(e))?, + .map_err(|e| http_client::Error::Protocol(e))?, ) .await?; @@ -287,7 +287,7 @@ where Ok(()) } else { let text: String = String::from_utf8_lossy(&response.into_body().await?).into(); - Err(VerifyError::HttpError(HttpClientError::Instance( + Err(VerifyError::HttpError(http_client::Error::Instance( format!("Failed with '{status}': {text}").into(), ))) } diff --git a/rig-core/src/providers/anthropic/streaming.rs b/rig-core/src/providers/anthropic/streaming.rs index 
fa16272b8..0cb090fd9 100644 --- a/rig-core/src/providers/anthropic/streaming.rs +++ b/rig-core/src/providers/anthropic/streaming.rs @@ -163,7 +163,7 @@ where .client .post("/v1/messages") .body(body) - .map_err(|e| http_client::HttpClientError::Protocol(e))?; + .map_err(|e| http_client::Error::Protocol(e))?; let response: http_client::StreamingResponse = self.client.send_streaming(req).await?; diff --git a/rig-core/src/providers/cohere/client.rs b/rig-core/src/providers/cohere/client.rs index a8dff1091..d85f43375 100644 --- a/rig-core/src/providers/cohere/client.rs +++ b/rig-core/src/providers/cohere/client.rs @@ -7,8 +7,9 @@ use crate::{ use super::{CompletionModel, EmbeddingModel}; use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient, impl_conversion_traits}; +use bytes::Bytes; +use reqwest_eventsource::{CannotCloneRequestError, EventSource}; use serde::Deserialize; -use url::ParseError; #[derive(Debug, Deserialize)] pub struct ApiErrorResponse { @@ -102,23 +103,35 @@ where ClientBuilder::with_client(api_key, reqwest::Client::new()).build() } - pub(crate) fn post(&self, path: &str) -> Result, ParseError> + pub(crate) fn post(&self, path: &str) -> http_client::Result where U: From + Send, { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - Ok(http_client::Request::post(url.as_str())?.bearer_auth(self.api_key.as_str())) + let auth_header = + http_client::HeaderValue::try_from(format!("Bearer {}", self.api_key.as_str())) + .map_err(http::Error::from)?; + + Ok(http_client::Request::post(url).header("Authorization", auth_header)) } - pub(crate) fn get(&self, path: &str) -> Result { + pub(crate) fn get(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - Ok(http_client::Request::get(url.as_str())?.bearer_auth(self.api_key.as_str())) + let auth_header = + http_client::HeaderValue::try_from(format!("Bearer {}", self.api_key.as_str())) + .map_err(http::Error::from)?; + + 
Ok(http_client::Request::get(url).header("Authorization", auth_header)) } - pub(crate) async fn send( + pub(crate) async fn send( &self, - req: http_client::Request, - ) -> Result::Error> { + req: http_client::Request, + ) -> http_client::Result>> + where + U: Into, + V: From + Send, + { self.http_client.request(req).await } @@ -157,9 +170,15 @@ where } impl Client { - pub(crate) async fn eventsource(&self, req: http_client::Request) -> _ { - let req: reqwest::Request = req.into(); - todo!() + pub(crate) async fn eventsource( + &self, + req: reqwest::RequestBuilder, + ) -> Result { + reqwest_eventsource::EventSource::new(req) + } + + pub(crate) fn client(&self) -> &reqwest::Client { + &self.http_client } } diff --git a/rig-core/src/providers/cohere/completion.rs b/rig-core/src/providers/cohere/completion.rs index 062130fb7..ac6bf1e32 100644 --- a/rig-core/src/providers/cohere/completion.rs +++ b/rig-core/src/providers/cohere/completion.rs @@ -1,7 +1,7 @@ use crate::{ OneOrMany, completion::{self, CompletionError}, - http_client::HttpClientExt, + http_client::{self, HttpClientExt}, json_utils, message::{self, Reasoning}, }; @@ -517,10 +517,7 @@ where } } -impl completion::CompletionModel for CompletionModel -where - T: HttpClientExt + Clone, -{ +impl completion::CompletionModel for CompletionModel { type Response = CompletionResponse; type StreamingResponse = StreamingCompletionResponse; @@ -535,23 +532,21 @@ where serde_json::to_string_pretty(&request)? ); - let req = self - .client - .post("/v2/chat") - .map_err(|e| CompletionError::HttpError(e.into()))? 
- .with_json(&request); - let response = self .client - .send(req) + .client() + .post("/v2/chat") + .json(&request) + .send() .await - .map_err(|e| CompletionError::HttpError(e.into()))?; - - let text = response - .text() - .map_err(|e| CompletionError::ResponseError(e.to_string()))?; + .map_err(|e| http_client::Error::Instance(e.into()))?; if response.status().is_success() { + let text = response + .text() + .await + .map_err(|e| CompletionError::ResponseError(e.to_string()))?; + tracing::debug!("Cohere response text: {}", text); let json_response: CompletionResponse = serde_json::from_str(&text)?; @@ -559,6 +554,10 @@ where json_response.try_into()?; Ok(completion) } else { + let text = response + .text() + .await + .map_err(|e| CompletionError::ResponseError(e.to_string()))?; Err(CompletionError::ProviderError(text.to_string())) } } diff --git a/rig-core/src/providers/cohere/embeddings.rs b/rig-core/src/providers/cohere/embeddings.rs index 12ef5a4d7..90602ccda 100644 --- a/rig-core/src/providers/cohere/embeddings.rs +++ b/rig-core/src/providers/cohere/embeddings.rs @@ -83,24 +83,32 @@ where ) -> Result, EmbeddingError> { let documents = documents.into_iter().collect::>(); + let body = json!({ + "model": self.model, + "texts": documents, + "input_type": self.input_type + }); + + let body = serde_json::to_vec(&body)?; + let req = self .client .post("/v1/embed") .map_err(|e| EmbeddingError::HttpError(e.into()))? - .with_json(&json!({ - "model": self.model, - "texts": documents, - "input_type": self.input_type, - })); + .body(body) + .map_err(|e| EmbeddingError::HttpError(e.into()))?; let response = self .client - .send(req) + .send::<_, Vec>(req) .await .map_err(|e| EmbeddingError::HttpError(e.into()))?; if response.status().is_success() { - match response.json::>()? 
{ + let body: ApiResponse = + serde_json::from_slice(response.into_body().await?.as_slice())?; + + match body { ApiResponse::Ok(response) => { match response.meta { Some(meta) => tracing::info!(target: "rig", @@ -136,12 +144,8 @@ where ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)), } } else { - Err(EmbeddingError::ProviderError( - response - .text() - .map_err(|e| EmbeddingError::HttpError(e.into()))? - .to_string(), - )) + let text = String::from_utf8_lossy(&response.into_body().await?).into(); + Err(EmbeddingError::ProviderError(text)) } } } diff --git a/rig-core/src/providers/cohere/streaming.rs b/rig-core/src/providers/cohere/streaming.rs index 18e6cced0..f6b6b6f13 100644 --- a/rig-core/src/providers/cohere/streaming.rs +++ b/rig-core/src/providers/cohere/streaming.rs @@ -1,12 +1,11 @@ use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage}; -use crate::http_client::HttpClientExt; use crate::providers::cohere::CompletionModel; use crate::providers::cohere::completion::Usage; use crate::streaming::RawStreamingChoice; use crate::{json_utils, streaming}; use async_stream::stream; use futures::StreamExt; -use reqwest_eventsource::{Event, RequestBuilderExt}; +use reqwest_eventsource::Event; use serde::{Deserialize, Serialize}; #[derive(Debug, Deserialize)] @@ -85,8 +84,7 @@ impl GetTokenUsage for StreamingCompletionResponse { } } -impl CompletionModel -{ +impl CompletionModel { pub(crate) async fn stream( &self, request: CompletionRequest, @@ -100,11 +98,12 @@ impl CompletionModel serde_json::to_string_pretty(&request)? ); - let req = self.client.post("/v2/chat").map_err(|e| CompletionError::HttpError(e.into()))?.with_json(&request) + let req = self.client.client().post("/v2/chat").json(&request); let mut event_source = self .client - .eventsource() + .eventsource::(req) + .await .map_err(|e| CompletionError::ProviderError(e.to_string()))?; let stream = Box::pin(stream! 
{ diff --git a/rig-core/src/providers/gemini/client.rs b/rig-core/src/providers/gemini/client.rs index fa11ae3ea..b6281a240 100644 --- a/rig-core/src/providers/gemini/client.rs +++ b/rig-core/src/providers/gemini/client.rs @@ -5,71 +5,90 @@ use crate::client::{ ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient, VerifyClient, VerifyError, impl_conversion_traits, }; +use crate::completion; +use crate::http_client::{self, HttpClientExt}; use crate::{ Embed, embeddings::{self}, }; +use bytes::Bytes; use serde::Deserialize; +use std::fmt::Debug; // ================================================================ // Google Gemini Client // ================================================================ const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { api_key: &'a str, base_url: &'a str, - http_client: Option, + http_client: T, } -impl<'a> ClientBuilder<'a> { - pub fn new(api_key: &'a str) -> Self { +impl<'a, T> ClientBuilder<'a, T> +where + T: HttpClientExt, +{ + pub fn new(api_key: &'a str) -> ClientBuilder<'a, reqwest::Client> { + ClientBuilder { + api_key, + base_url: GEMINI_API_BASE_URL, + http_client: Default::default(), + } + } + + pub fn new_with_client(api_key: &'a str, http_client: T) -> Self { Self { api_key, base_url: GEMINI_API_BASE_URL, - http_client: None, + http_client, } } - pub fn base_url(mut self, base_url: &'a str) -> Self { - self.base_url = base_url; - self + pub fn with_client(self, http_client: U) -> ClientBuilder<'a, U> + where + U: HttpClientExt, + { + ClientBuilder { + api_key: self.api_key, + base_url: self.base_url, + http_client, + } } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); + pub fn base_url(mut self, base_url: &'a str) -> Self { + self.base_url = base_url; self } - pub fn build(self) -> Result { + pub fn build(self) -> Result, 
ClientBuilderError> { let mut default_headers = reqwest::header::HeaderMap::new(); default_headers.insert( reqwest::header::CONTENT_TYPE, "application/json".parse().unwrap(), ); - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? - }; Ok(Client { base_url: self.base_url.to_string(), api_key: self.api_key.to_string(), default_headers, - http_client, + http_client: self.http_client, }) } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, default_headers: reqwest::header::HeaderMap, - http_client: reqwest::Client, + http_client: T, } -impl std::fmt::Debug for Client { +impl Debug for Client +where + T: Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -80,7 +99,10 @@ impl std::fmt::Debug for Client { } } -impl Client { +impl Client +where + T: HttpClientExt + Default, +{ /// Create a new Google Gemini client builder. /// /// # Example @@ -91,8 +113,8 @@ impl Client { /// let gemini_client = Client::builder("your-google-gemini-api-key") /// .build() /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_> { - ClientBuilder::new(api_key) + pub fn builder(api_key: &str) -> ClientBuilder<'_, T> { + ClientBuilder::new_with_client(api_key, Default::default()) } /// Create a new Google Gemini client. For more control, use the `builder` method. 
@@ -104,39 +126,69 @@ impl Client { .build() .expect("Gemini client should build") } +} - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { - // API key gets inserted as query param - no need to add bearer auth or headers - let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/"); +impl Client { + pub(crate) fn post_sse(&self, path: &str) -> reqwest::RequestBuilder { + let url = + format!("{}/{}?alt=sse&key={}", self.base_url, path, self.api_key).replace("//", "/"); + + tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****"); - tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****"); self.http_client .post(url) .headers(self.default_headers.clone()) } +} + +impl Client +where + T: HttpClientExt, +{ + pub(crate) fn post(&self, path: &str) -> http_client::Builder { + // API key gets inserted as query param - no need to add bearer auth or headers + let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/"); + + tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****"); + let mut req = http_client::Request::post(url); + + if let Some(hs) = req.headers_mut() { + *hs = self.default_headers.clone(); + } + + req + } - pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder { + pub(crate) fn get(&self, path: &str) -> http_client::Builder { // API key gets inserted as query param - no need to add bearer auth or headers let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/"); tracing::debug!("GET {}/{}?key={}", self.base_url, path, "****"); - self.http_client - .get(url) - .headers(self.default_headers.clone()) - } - pub(crate) fn post_sse(&self, path: &str) -> reqwest::RequestBuilder { - let url = - format!("{}/{}?alt=sse&key={}", self.base_url, path, self.api_key).replace("//", "/"); + let mut req = http_client::Request::get(url); - tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****"); - 
self.http_client - .post(url) - .headers(self.default_headers.clone()) + if let Some(hs) = req.headers_mut() { + *hs = self.default_headers.clone(); + } + + req + } + + pub(crate) async fn send( + &self, + req: http_client::Request, + ) -> http_client::Result>> + where + U: Into, + R: From + Send, + { + self.http_client.request(req).await } } -impl ProviderClient for Client { +// NOTE: (@FayCarsons) This cannot be implemented for all T because `AsCompletion`/`CompletionModel` requires SSE +// which we are not able to implement for any `T: HttpClientExt` right now +impl ProviderClient for Client { /// Create a new Google Gemini client from the `GEMINI_API_KEY` environment variable. /// Panics if the environment variable is not set. fn from_env() -> Self { @@ -152,19 +204,23 @@ impl ProviderClient for Client { } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; /// Create a completion model with the given name. /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct. /// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig) - fn completion_model(&self, model: &str) -> CompletionModel { + fn completion_model(&self, model: &str) -> Self::CompletionModel { CompletionModel::new(self.clone(), model) } } -impl EmbeddingsClient for Client { - type EmbeddingModel = EmbeddingModel; +impl EmbeddingsClient for Client +where + T: HttpClientExt + Clone + Debug + Default + 'static, + Client: CompletionClient, +{ + type EmbeddingModel = EmbeddingModel; /// Create an embedding model with the given name. /// Note: default embedding dimension of 0 will be used if model is not known. 
@@ -179,7 +235,7 @@ impl EmbeddingsClient for Client { /// /// let embedding_model = gemini.embedding_model(gemini::embedding::EMBEDDING_GECKO_001); /// ``` - fn embedding_model(&self, model: &str) -> EmbeddingModel { + fn embedding_model(&self, model: &str) -> EmbeddingModel { EmbeddingModel::new(self.clone(), model, None) } @@ -194,7 +250,7 @@ impl EmbeddingsClient for Client { /// /// let embedding_model = gemini.embedding_model_with_ndims("model-unknown-to-rig", 1024); /// ``` - fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel { + fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel { EmbeddingModel::new(self.clone(), model, Some(ndims)) } @@ -217,35 +273,49 @@ impl EmbeddingsClient for Client { fn embeddings( &self, model: &str, - ) -> embeddings::EmbeddingsBuilder { + ) -> embeddings::EmbeddingsBuilder, D> { embeddings::EmbeddingsBuilder::new(self.embedding_model(model)) } } -impl TranscriptionClient for Client { - type TranscriptionModel = TranscriptionModel; +impl TranscriptionClient for Client +where + T: HttpClientExt + Clone + Debug + Default + 'static, + Client: CompletionClient, +{ + type TranscriptionModel = TranscriptionModel; /// Create a transcription model with the given name. /// Gemini-specific parameters can be set using the [GenerationConfig](crate::providers::gemini::completion::gemini_api_types::GenerationConfig) struct. 
/// [Gemini API Reference](https://ai.google.dev/api/generate-content#generationconfig) - fn transcription_model(&self, model: &str) -> TranscriptionModel { + fn transcription_model(&self, model: &str) -> TranscriptionModel { TranscriptionModel::new(self.clone(), model) } } -impl VerifyClient for Client { +impl VerifyClient for Client +where + T: HttpClientExt + Clone + Debug + Default + 'static, + Client: CompletionClient, +{ #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let response = self.get("/v1beta/models").send().await?; + let req = self.get("/v1beta/models").body(http_client::NoBody)?; + let response = self.http_client.request::<_, Vec>(req).await?; + match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::FORBIDDEN => Err(VerifyError::InvalidAuthentication), reqwest::StatusCode::INTERNAL_SERVER_ERROR | reqwest::StatusCode::SERVICE_UNAVAILABLE => { - Err(VerifyError::ProviderError(response.text().await?)) + let text = String::from_utf8_lossy(&response.into_body().await?).into(); + Err(VerifyError::ProviderError(text)) } _ => { - response.error_for_status()?; + // TODO: Find/write some alternative for this that uses `http::StatusCode` vs + // reqwest::StatusCode + // + // response.error_for_status()?; Ok(()) } } @@ -254,7 +324,7 @@ impl VerifyClient for Client { impl_conversion_traits!( AsImageGeneration, - AsAudioGeneration for Client + AsAudioGeneration for Client ); #[derive(Debug, Deserialize)] diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index 61fc8315d..68a110ac3 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -28,6 +28,7 @@ pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b"; pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro"; use self::gemini_api_types::Schema; +use crate::http_client::HttpClientExt; use crate::message::Reasoning; use 
crate::providers::gemini::completion::gemini_api_types::AdditionalParameters; use crate::providers::gemini::streaming::StreamingCompletionResponse; @@ -49,13 +50,13 @@ use super::Client; // ================================================================= #[derive(Clone)] -pub struct CompletionModel { - pub(crate) client: Client, +pub struct CompletionModel { + pub(crate) client: Client, pub model: String, } -impl CompletionModel { - pub fn new(client: Client, model: &str) -> Self { +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -63,7 +64,7 @@ impl CompletionModel { } } -impl completion::CompletionModel for CompletionModel { +impl completion::CompletionModel for CompletionModel { type Response = GenerateContentResponse; type StreamingResponse = StreamingCompletionResponse; @@ -72,23 +73,31 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { - let request = create_request_body(completion_request)?; + let body = create_request_body(completion_request) + .and_then(|body| serde_json::to_vec(&body).map_err(Into::into))?; tracing::debug!( "Sending completion request to Gemini API {}", - serde_json::to_string_pretty(&request)? 
+ String::from_utf8_lossy(&body) ); - let response = self + let request = self .client .post(&format!("/v1beta/models/{}:generateContent", self.model)) - .json(&request) - .send() - .await?; + .body(body) + .map_err(|e| CompletionError::HttpError(e.into()))?; + + let response = self.client.send::<_, Vec>(request).await?; if response.status().is_success() { - let response = response.json::().await?; - match response.usage_metadata { + let response_body = response + .into_body() + .await + .map_err(|e| CompletionError::HttpError(e))?; + + let body: GenerateContentResponse = serde_json::from_slice(&response_body)?; + + match body.usage_metadata { Some(ref usage) => tracing::info!(target: "rig", "Gemini completion token usage: {}", usage @@ -100,10 +109,18 @@ impl completion::CompletionModel for CompletionModel { tracing::debug!("Received response"); - Ok(completion::CompletionResponse::try_from(response)) + Ok(completion::CompletionResponse::try_from(body)?) } else { - Err(CompletionError::ProviderError(response.text().await?)) - }? 
+ let text = String::from_utf8_lossy( + &response + .into_body() + .await + .map_err(|e| CompletionError::HttpError(e.into()))?, + ) + .into(); + + Err(CompletionError::ProviderError(text)) + } } #[cfg_attr(feature = "worker", worker::send)] @@ -585,7 +602,7 @@ pub mod gemini_api_types { file_uri: url, }), DocumentSourceKind::Base64(data) => PartKind::InlineData(Blob { mime_type, data }), - DocumentSourceKind::Unknown => { + _ => { return Err(message::MessageError::ConversionError( "Can't convert an unknown document source".to_string(), )); @@ -678,7 +695,7 @@ pub mod gemini_api_types { DocumentSourceKind::Base64(data) => { PartKind::InlineData(Blob { data, mime_type }) } - DocumentSourceKind::Unknown => { + _ => { return Err(message::MessageError::ConversionError( "Document has no body".to_string(), )); @@ -714,7 +731,7 @@ pub mod gemini_api_types { mime_type: Some(mime_type), file_uri, }), - DocumentSourceKind::Unknown => { + _ => { return Err(message::MessageError::ConversionError( "Content has no body".to_string(), )); @@ -747,7 +764,7 @@ pub mod gemini_api_types { DocumentSourceKind::Base64(data) => { PartKind::InlineData(Blob { mime_type, data }) } - DocumentSourceKind::Unknown => { + _ => { return Err(message::MessageError::ConversionError( "Media type for video is required for Gemini".to_string(), )); diff --git a/rig-core/src/providers/gemini/embedding.rs b/rig-core/src/providers/gemini/embedding.rs index 16134de0e..f4da9de1a 100644 --- a/rig-core/src/providers/gemini/embedding.rs +++ b/rig-core/src/providers/gemini/embedding.rs @@ -5,7 +5,10 @@ use serde_json::json; -use crate::embeddings::{self, EmbeddingError}; +use crate::{ + embeddings::{self, EmbeddingError}, + http_client::HttpClientExt, +}; use super::{Client, client::ApiResponse}; @@ -14,14 +17,14 @@ pub const EMBEDDING_001: &str = "embedding-001"; /// `text-embedding-004` embedding model pub const EMBEDDING_004: &str = "text-embedding-004"; #[derive(Clone)] -pub struct EmbeddingModel { - client: 
Client, +pub struct EmbeddingModel { + client: Client, model: String, ndims: Option, } -impl EmbeddingModel { - pub fn new(client: Client, model: &str, ndims: Option) -> Self { +impl EmbeddingModel { + pub fn new(client: Client, model: &str, ndims: Option) -> Self { Self { client, model: model.to_string(), @@ -30,7 +33,10 @@ impl EmbeddingModel { } } -impl embeddings::EmbeddingModel for EmbeddingModel { +impl embeddings::EmbeddingModel for EmbeddingModel +where + T: Send + Sync + Clone + HttpClientExt, +{ const MAX_DOCUMENTS: usize = 1024; fn ndims(&self) -> usize { @@ -68,15 +74,16 @@ impl embeddings::EmbeddingModel for EmbeddingModel { tracing::info!("{}", serde_json::to_string_pretty(&request_body).unwrap()); - let response = self + let request_body = serde_json::to_vec(&request_body)?; + let req = self .client .post(&format!("/v1beta/models/{}:batchEmbedContents", self.model)) - .json(&request_body) - .send() - .await? - .error_for_status()? - .json::>() - .await?; + .body(request_body) + .map_err(|e| EmbeddingError::HttpError(e.into()))?; + let response = self.client.send::<_, Vec>(req).await?; + + let response: ApiResponse = + serde_json::from_slice(&response.into_body().await?)?; match response { ApiResponse::Ok(response) => { diff --git a/rig-core/src/providers/gemini/streaming.rs b/rig-core/src/providers/gemini/streaming.rs index 17716022f..395d96ffe 100644 --- a/rig-core/src/providers/gemini/streaming.rs +++ b/rig-core/src/providers/gemini/streaming.rs @@ -53,7 +53,7 @@ impl GetTokenUsage for StreamingCompletionResponse { } } -impl CompletionModel { +impl CompletionModel { pub(crate) async fn stream( &self, completion_request: CompletionRequest, diff --git a/rig-core/src/providers/gemini/transcription.rs b/rig-core/src/providers/gemini/transcription.rs index 0c601a478..6b67f4f21 100644 --- a/rig-core/src/providers/gemini/transcription.rs +++ b/rig-core/src/providers/gemini/transcription.rs @@ -5,6 +5,7 @@ use mime_guess; use serde_json::{Map, Value}; use 
crate::{ + http_client::HttpClientExt, providers::gemini::completion::gemini_api_types::{ Blob, Content, GenerateContentRequest, GenerationConfig, Part, PartKind, Role, }, @@ -21,14 +22,14 @@ const TRANSCRIPTION_PREAMBLE: &str = "Translate the provided audio exactly. Do not add additional information."; #[derive(Clone)] -pub struct TranscriptionModel { - client: Client, +pub struct TranscriptionModel { + client: Client, /// Name of the model (e.g.: gemini-1.5-flash) pub model: String, } -impl TranscriptionModel { - pub fn new(client: Client, model: &str) -> Self { +impl TranscriptionModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -36,7 +37,10 @@ impl TranscriptionModel { } } -impl transcription::TranscriptionModel for TranscriptionModel { +impl transcription::TranscriptionModel for TranscriptionModel +where + T: HttpClientExt + Send + Sync + Clone, +{ type Response = GenerateContentResponse; #[cfg_attr(feature = "worker", worker::send)] @@ -96,16 +100,20 @@ impl transcription::TranscriptionModel for TranscriptionModel { serde_json::to_string_pretty(&request)? 
); - let response = self + let body = serde_json::to_vec(&request)?; + let req = self .client .post(&format!("/v1beta/models/{}:generateContent", self.model)) - .json(&request) - .send() - .await?; + .body(body) + .map_err(|e| TranscriptionError::HttpError(e.into()))?; + + let response = self.client.send::<_, Vec>(req).await?; if response.status().is_success() { - let response = response.json::().await?; - match response.usage_metadata { + let body: GenerateContentResponse = + serde_json::from_slice(&response.into_body().await?)?; + + match body.usage_metadata { Some(ref usage) => tracing::info!(target: "rig", "Gemini completion token usage: {}", usage @@ -117,10 +125,11 @@ impl transcription::TranscriptionModel for TranscriptionModel { tracing::debug!("Received response"); - Ok(transcription::TranscriptionResponse::try_from(response)) + Ok(transcription::TranscriptionResponse::try_from(body)?) } else { - Err(TranscriptionError::ProviderError(response.text().await?)) - }? + let text = String::from_utf8_lossy(&response.into_body().await?).into(); + Err(TranscriptionError::ProviderError(text)) + } } } diff --git a/rig-core/src/providers/huggingface/client.rs b/rig-core/src/providers/huggingface/client.rs index 38b7d037e..e9cced191 100644 --- a/rig-core/src/providers/huggingface/client.rs +++ b/rig-core/src/providers/huggingface/client.rs @@ -5,13 +5,16 @@ use crate::client::{ ClientBuilderError, CompletionClient, ProviderClient, TranscriptionClient, VerifyClient, VerifyError, }; +use crate::http_client::{self, HttpClientExt}; #[cfg(feature = "image")] use crate::image_generation::ImageGenerationError; #[cfg(feature = "image")] use crate::providers::huggingface::image_generation::ImageGenerationModel; use crate::providers::huggingface::transcription::TranscriptionModel; use crate::transcription::TranscriptionError; +use bytes::Bytes; use rig::client::impl_conversion_traits; +use std::fmt::Debug; use std::fmt::Display; // 
================================================================ @@ -105,20 +108,34 @@ impl Display for SubProvider { } } -pub struct ClientBuilder { +pub struct ClientBuilder { api_key: String, base_url: String, sub_provider: SubProvider, - http_client: Option, + http_client: T, } -impl ClientBuilder { - pub fn new(api_key: &str) -> Self { - Self { +impl ClientBuilder +where + T: Default, +{ + pub fn new(api_key: &str) -> ClientBuilder { + ClientBuilder { api_key: api_key.to_string(), base_url: HUGGINGFACE_API_BASE_URL.to_string(), sub_provider: SubProvider::default(), - http_client: None, + http_client: Default::default(), + } + } +} + +impl ClientBuilder { + pub fn with_client(self, http_client: U) -> ClientBuilder { + ClientBuilder { + api_key: self.api_key, + base_url: self.base_url, + sub_provider: self.sub_provider, + http_client, } } @@ -132,12 +149,7 @@ impl ClientBuilder { self } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); - self - } - - pub fn build(self) -> Result { + pub fn build(self) -> Result, ClientBuilderError> { let route = self.sub_provider.to_string(); let base_url = format!("{}/{}", self.base_url, route).replace("//", "/"); @@ -148,32 +160,30 @@ impl ClientBuilder { .parse() .expect("Failed to parse Content-Type"), ); - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? 
- }; Ok(Client { base_url, default_headers, api_key: self.api_key, - http_client, + http_client: self.http_client, sub_provider: self.sub_provider, }) } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, default_headers: reqwest::header::HeaderMap, api_key: String, - http_client: reqwest::Client, + http_client: T, pub(crate) sub_provider: SubProvider, } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client +where + T: Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -185,7 +195,10 @@ impl std::fmt::Debug for Client { } } -impl Client { +impl Client +where + T: Default, +{ /// Create a new Huggingface client builder. /// /// # Example @@ -196,7 +209,7 @@ impl Client { /// let client = Client::builder("your-huggingface-api-key") /// .build() /// ``` - pub fn builder(api_key: &str) -> ClientBuilder { + pub fn builder(api_key: &str) -> ClientBuilder { ClientBuilder::new(api_key) } @@ -209,25 +222,70 @@ impl Client { .build() .expect("Huggingface client should build") } +} + +impl Client { + pub(crate) fn client(&self) -> &reqwest::Client { + &self.http_client + } - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { + pub(crate) fn post_reqwest(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + self.http_client .post(url) - .bearer_auth(&self.api_key) .headers(self.default_headers.clone()) + .bearer_auth(&self.api_key) } +} - pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder { +impl Client +where + T: HttpClientExt, +{ + pub(crate) fn post(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client - .get(url) - .bearer_auth(&self.api_key) - .headers(self.default_headers.clone()) + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) 
+ .map_err(|e| http_client::Error::Protocol(e.into()))?; + + let mut req = http_client::Request::post(url).header("Authorization", auth_header); + + if let Some(hs) = req.headers_mut() { + *hs = self.default_headers.clone(); + } + + Ok(req) + } + + pub(crate) fn get(&self, path: &str) -> http_client::Result { + let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(|e| http_client::Error::Protocol(e.into()))?; + + let mut req = http_client::Request::get(url).header("Authorization", auth_header); + + if let Some(hs) = req.headers_mut() { + *hs = self.default_headers.clone(); + } + + Ok(req) + } + + pub(crate) async fn send( + &self, + req: http_client::Request, + ) -> http_client::Result>> + where + U: Into, + V: From + Send, + { + self.http_client.request(req).await } } -impl ProviderClient for Client { +impl ProviderClient for Client { /// Create a new Huggingface client from the `HUGGINGFACE_API_KEY` environment variable. /// Panics if the environment variable is not set. 
fn from_env() -> Self { @@ -243,8 +301,8 @@ impl ProviderClient for Client { } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; /// Create a new completion model with the given name /// @@ -257,13 +315,13 @@ impl CompletionClient for Client { /// /// let completion_model = client.completion_model(huggingface::GEMMA_2); /// ``` - fn completion_model(&self, model: &str) -> CompletionModel { + fn completion_model(&self, model: &str) -> CompletionModel { CompletionModel::new(self.clone(), model) } } -impl TranscriptionClient for Client { - type TranscriptionModel = TranscriptionModel; +impl TranscriptionClient for Client { + type TranscriptionModel = TranscriptionModel; /// Create a new transcription model with the given name /// @@ -277,14 +335,14 @@ impl TranscriptionClient for Client { /// let completion_model = client.transcription_model(huggingface::WHISPER_LARGE_V3); /// ``` /// - fn transcription_model(&self, model: &str) -> TranscriptionModel { + fn transcription_model(&self, model: &str) -> TranscriptionModel { TranscriptionModel::new(self.clone(), model) } } #[cfg(feature = "image")] -impl ImageGenerationClient for Client { - type ImageGenerationModel = ImageGenerationModel; +impl ImageGenerationClient for Client { + type ImageGenerationModel = ImageGenerationModel; /// Create a new image generation model with the given name /// @@ -297,20 +355,24 @@ impl ImageGenerationClient for Client { /// /// let completion_model = client.image_generation_model(huggingface::WHISPER_LARGE_V3); /// ``` - fn image_generation_model(&self, model: &str) -> ImageGenerationModel { + fn image_generation_model(&self, model: &str) -> ImageGenerationModel { ImageGenerationModel::new(self.clone(), model) } } -impl VerifyClient for Client { +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let 
response = self.get("/api/whoami-v2").send().await?; + let req = self.get("/api/whoami-v2")?.body(http_client::NoBody)?; + let req = reqwest::Request::try_from(req)?; + let response: reqwest::Response = self.http_client.execute(req).await?; + match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication), reqwest::StatusCode::INTERNAL_SERVER_ERROR => { - Err(VerifyError::ProviderError(response.text().await?)) + let text = response.text().await?; + Err(VerifyError::ProviderError(text)) } _ => { response.error_for_status()?; @@ -320,4 +382,4 @@ impl VerifyClient for Client { } } -impl_conversion_traits!(AsEmbeddings, AsAudioGeneration for Client); +impl_conversion_traits!(AsEmbeddings, AsAudioGeneration for Client); diff --git a/rig-core/src/providers/huggingface/completion.rs b/rig-core/src/providers/huggingface/completion.rs index fbbcfc2c1..cd4bb2832 100644 --- a/rig-core/src/providers/huggingface/completion.rs +++ b/rig-core/src/providers/huggingface/completion.rs @@ -3,6 +3,7 @@ use serde_json::{Value, json}; use std::{convert::Infallible, str::FromStr}; use super::client::Client; +use crate::http_client::HttpClientExt; use crate::providers::openai::StreamingCompletionResponse; use crate::{ OneOrMany, @@ -496,14 +497,14 @@ impl TryFrom for completion::CompletionResponse { + pub(crate) client: Client, /// Name of the model (e.g: google/gemma-2-2b-it) pub model: String, } -impl CompletionModel { - pub fn new(client: Client, model: &str) -> Self { +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -556,7 +557,7 @@ impl CompletionModel { } } -impl completion::CompletionModel for CompletionModel { +impl completion::CompletionModel for CompletionModel { type Response = CompletionResponse; type StreamingResponse = StreamingCompletionResponse; @@ -575,13 +576,23 @@ impl completion::CompletionModel for CompletionModel { request 
}; - let response = self.client.post(&path).json(&request).send().await?; + let request = serde_json::to_vec(&request)?; + + let request = self + .client + .post(&path)? + .body(request) + .map_err(|e| CompletionError::HttpError(e.into()))?; + + let response = self.client.send(request).await?; if response.status().is_success() { - let t = response.text().await?; - tracing::debug!(target: "rig", "Huggingface completion error: {}", t); + let bytes: Vec = response.into_body().await?; + let text = String::from_utf8_lossy(&bytes); + + tracing::debug!(target: "rig", "Huggingface completion error: {}", text); - match serde_json::from_str::>(&t)? { + match serde_json::from_slice::>(&bytes)? { ApiResponse::Ok(response) => { tracing::info!(target: "rig", "Huggingface completion token usage: {:?}", @@ -592,10 +603,12 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())), } } else { + let text: Vec = response.into_body().await?; + let text = String::from_utf8_lossy(&text).into(); Err(CompletionError::ProviderError(format!( "{}: {}", response.status(), - response.text().await? 
+ text ))) } } diff --git a/rig-core/src/providers/huggingface/image_generation.rs b/rig-core/src/providers/huggingface/image_generation.rs index 753119883..31ea98406 100644 --- a/rig-core/src/providers/huggingface/image_generation.rs +++ b/rig-core/src/providers/huggingface/image_generation.rs @@ -1,4 +1,5 @@ use super::Client; +use crate::http_client::HttpClientExt; use crate::image_generation; use crate::image_generation::{ImageGenerationError, ImageGenerationRequest}; use serde_json::json; @@ -26,13 +27,13 @@ impl TryFrom } #[derive(Clone)] -pub struct ImageGenerationModel { - client: Client, +pub struct ImageGenerationModel { + client: Client, pub model: String, } -impl ImageGenerationModel { - pub fn new(client: Client, model: &str) -> Self { +impl ImageGenerationModel { + pub fn new(client: Client, model: &str) -> Self { ImageGenerationModel { client, model: model.to_string(), @@ -40,7 +41,10 @@ impl ImageGenerationModel { } } -impl image_generation::ImageGenerationModel for ImageGenerationModel { +impl image_generation::ImageGenerationModel for ImageGenerationModel +where + T: HttpClientExt + Send + Clone + 'static, +{ type Response = ImageGenerationResponse; #[cfg_attr(feature = "worker", worker::send)] @@ -62,17 +66,28 @@ impl image_generation::ImageGenerationModel for ImageGenerationModel { .sub_provider .image_generation_endpoint(&self.model)?; - let response = self.client.post(&route).json(&request).send().await?; + let body = serde_json::to_vec(&request)?; + + let req = self + .client + .post(&route)? + .body(body) + .map_err(|e| ImageGenerationError::HttpError(e.into()))?; + + let response = self.client.send(req).await?; if !response.status().is_success() { + let text: Vec = response.into_body().await?; + let text = String::from_utf8_lossy(&text).into(); + return Err(ImageGenerationError::ProviderError(format!( "{}: {}", response.status(), - response.text().await? 
+ text ))); } - let data = response.bytes().await?.to_vec(); + let data: Vec = response.into_body().await?; ImageGenerationResponse { data }.try_into() } diff --git a/rig-core/src/providers/huggingface/streaming.rs b/rig-core/src/providers/huggingface/streaming.rs index edd405aa6..35ce1764c 100644 --- a/rig-core/src/providers/huggingface/streaming.rs +++ b/rig-core/src/providers/huggingface/streaming.rs @@ -5,7 +5,7 @@ use crate::providers::openai::{StreamingCompletionResponse, send_compatible_stre use crate::streaming; use serde_json::json; -impl CompletionModel { +impl CompletionModel { pub(crate) async fn stream( &self, completion_request: CompletionRequest, @@ -26,7 +26,9 @@ impl CompletionModel { // HF Inference API uses the model in the path even though its specified in the request body let path = self.client.sub_provider.completion_endpoint(&self.model); - let builder = self.client.post(&path).json(&request); + let request = serde_json::to_vec(&request)?; + + let builder = self.client.post_reqwest(&path).body(request); send_compatible_streaming_request(builder).await } diff --git a/rig-core/src/providers/huggingface/transcription.rs b/rig-core/src/providers/huggingface/transcription.rs index 70f99fdc0..a776d5142 100644 --- a/rig-core/src/providers/huggingface/transcription.rs +++ b/rig-core/src/providers/huggingface/transcription.rs @@ -1,3 +1,4 @@ +use crate::http_client::{self, HttpClientExt}; use crate::providers::huggingface::Client; use crate::providers::huggingface::completion::ApiResponse; use crate::transcription; @@ -30,21 +31,24 @@ impl TryFrom } #[derive(Clone)] -pub struct TranscriptionModel { - client: Client, +pub struct TranscriptionModel { + client: Client, /// Name of the model (e.g.: gpt-3.5-turbo-1106) pub model: String, } -impl TranscriptionModel { - pub fn new(client: Client, model: &str) -> Self { +impl TranscriptionModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), } } } -impl 
transcription::TranscriptionModel for TranscriptionModel { +impl transcription::TranscriptionModel for TranscriptionModel +where + T: HttpClientExt + Clone, +{ type Response = TranscriptionResponse; #[cfg_attr(feature = "worker", worker::send)] @@ -63,18 +67,29 @@ impl transcription::TranscriptionModel for TranscriptionModel { .client .sub_provider .transcription_endpoint(&self.model)?; - let response = self.client.post(&route).json(&request).send().await?; + + let request = serde_json::to_vec(&request)?; + + let req = self + .client + .post(&route)? + .body(request) + .map_err(|e| TranscriptionError::HttpError(e.into()))?; + + let response = self.client.send(req).await?; if response.status().is_success() { - match response - .json::>() - .await? - { + let body: Vec = response.into_body().await?; + let body: ApiResponse = serde_json::from_slice(&body)?; + match body { ApiResponse::Ok(response) => response.try_into(), ApiResponse::Err(err) => Err(TranscriptionError::ProviderError(err.to_string())), } } else { - Err(TranscriptionError::ProviderError(response.text().await?)) + let text: Vec = response.into_body().await?; + let text = String::from_utf8_lossy(&text).into(); + + Err(TranscriptionError::ProviderError(text)) } } } diff --git a/rig-core/src/transcription.rs b/rig-core/src/transcription.rs index 9b8c8ae9a..24932e820 100644 --- a/rig-core/src/transcription.rs +++ b/rig-core/src/transcription.rs @@ -3,7 +3,7 @@ //! handling transcription responses, and defining transcription models. use crate::client::transcription::TranscriptionModelHandle; -use crate::json_utils; +use crate::{http_client, json_utils}; use futures::future::BoxFuture; use std::sync::Arc; use std::{fs, path::Path}; @@ -15,7 +15,7 @@ use thiserror::Error; pub enum TranscriptionError { /// Http error (e.g.: connection error, timeout, etc.) 
#[error("HttpError: {0}")] - HttpError(#[from] reqwest::Error), + HttpError(#[from] http_client::Error), /// Json error (e.g.: serialization, deserialization) #[error("JsonError: {0}")] From c3de5ac86932ac453b156f4ee7cd05d20fcf53e5 Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Tue, 30 Sep 2025 17:47:50 -0400 Subject: [PATCH 03/20] More clients --- rig-core/src/client/verify.rs | 2 +- rig-core/src/http_client.rs | 5 + rig-core/src/providers/cohere/client.rs | 15 +- rig-core/src/providers/cohere/embeddings.rs | 2 +- rig-core/src/providers/gemini/client.rs | 5 +- rig-core/src/providers/huggingface/client.rs | 25 ++- .../src/providers/huggingface/completion.rs | 7 +- .../providers/huggingface/image_generation.rs | 6 +- rig-core/src/providers/mistral/client.rs | 162 +++++++++++------ rig-core/src/providers/mistral/completion.rs | 34 ++-- rig-core/src/providers/mistral/embedding.rs | 45 +++-- rig-core/src/providers/openai/client.rs | 163 ++++++++++++------ .../src/providers/openai/completion/mod.rs | 49 ++++-- .../providers/openai/completion/streaming.rs | 4 +- rig-core/src/providers/openai/embedding.rs | 44 +++-- .../src/providers/openai/image_generation.rs | 39 +++-- .../src/providers/openai/responses_api/mod.rs | 43 +++-- .../openai/responses_api/streaming.rs | 4 +- .../src/providers/openai/transcription.rs | 28 +-- 19 files changed, 453 insertions(+), 229 deletions(-) diff --git a/rig-core/src/client/verify.rs b/rig-core/src/client/verify.rs index 3e64a2880..81a565bc3 100644 --- a/rig-core/src/client/verify.rs +++ b/rig-core/src/client/verify.rs @@ -15,7 +15,7 @@ pub enum VerifyError { HttpError( #[from] #[source] - http_client::HttpClientError, + http_client::Error, ), } diff --git a/rig-core/src/http_client.rs b/rig-core/src/http_client.rs index 74a8de967..195ad1d71 100644 --- a/rig-core/src/http_client.rs +++ b/rig-core/src/http_client.rs @@ -39,6 +39,11 @@ impl From for Body { } } +pub async fn text(response: Response>>) -> Result { + let text = 
response.into_body().await?; + Ok(String::from(String::from_utf8_lossy(&text))) +} + pub trait HttpClientExt: Send + Sync { fn request( &self, diff --git a/rig-core/src/providers/cohere/client.rs b/rig-core/src/providers/cohere/client.rs index d85f43375..892b7dd04 100644 --- a/rig-core/src/providers/cohere/client.rs +++ b/rig-core/src/providers/cohere/client.rs @@ -225,16 +225,25 @@ impl EmbeddingsClient for Client { impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let response = self.http_client.get("/v1/models").send().await?; + let response = self + .http_client + .get("/v1/models") + .send() + .await + .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?; match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication), reqwest::StatusCode::INTERNAL_SERVER_ERROR => { - Err(VerifyError::ProviderError(response.text().await?)) + Err(VerifyError::ProviderError(response.text().await.map_err( + |e| VerifyError::HttpError(http_client::Error::Instance(e.into())), + )?)) } _ => { - response.error_for_status()?; + response + .error_for_status() + .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?; Ok(()) } } diff --git a/rig-core/src/providers/cohere/embeddings.rs b/rig-core/src/providers/cohere/embeddings.rs index 90602ccda..f8a60746a 100644 --- a/rig-core/src/providers/cohere/embeddings.rs +++ b/rig-core/src/providers/cohere/embeddings.rs @@ -93,7 +93,7 @@ where let req = self .client - .post("/v1/embed") + .post::>("/v1/embed") .map_err(|e| EmbeddingError::HttpError(e.into()))? 
.body(body) .map_err(|e| EmbeddingError::HttpError(e.into()))?; diff --git a/rig-core/src/providers/gemini/client.rs b/rig-core/src/providers/gemini/client.rs index b6281a240..e6838e881 100644 --- a/rig-core/src/providers/gemini/client.rs +++ b/rig-core/src/providers/gemini/client.rs @@ -300,7 +300,10 @@ where { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let req = self.get("/v1beta/models").body(http_client::NoBody)?; + let req = self + .get("/v1beta/models") + .body(http_client::NoBody) + .map_err(|e| VerifyError::HttpError(e.into()))?; let response = self.http_client.request::<_, Vec>(req).await?; match response.status() { diff --git a/rig-core/src/providers/huggingface/client.rs b/rig-core/src/providers/huggingface/client.rs index e9cced191..5ad78f42d 100644 --- a/rig-core/src/providers/huggingface/client.rs +++ b/rig-core/src/providers/huggingface/client.rs @@ -363,19 +363,34 @@ impl ImageGenerationClient for Client { impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let req = self.get("/api/whoami-v2")?.body(http_client::NoBody)?; - let req = reqwest::Request::try_from(req)?; - let response: reqwest::Response = self.http_client.execute(req).await?; + let req = self + .get("/api/whoami-v2")? 
+ .body(http_client::NoBody) + .map_err(|e| VerifyError::HttpError(e.into()))?; + + let req = reqwest::Request::try_from(req) + .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?; + + let response: reqwest::Response = self + .http_client + .execute(req) + .await + .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?; match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication), reqwest::StatusCode::INTERNAL_SERVER_ERROR => { - let text = response.text().await?; + let text = response + .text() + .await + .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?; Err(VerifyError::ProviderError(text)) } _ => { - response.error_for_status()?; + response + .error_for_status() + .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?; Ok(()) } } diff --git a/rig-core/src/providers/huggingface/completion.rs b/rig-core/src/providers/huggingface/completion.rs index cd4bb2832..be39ea2cf 100644 --- a/rig-core/src/providers/huggingface/completion.rs +++ b/rig-core/src/providers/huggingface/completion.rs @@ -603,12 +603,13 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())), } } else { + let status = response.status(); let text: Vec = response.into_body().await?; - let text = String::from_utf8_lossy(&text).into(); + let text: String = String::from_utf8_lossy(&text).into(); + Err(CompletionError::ProviderError(format!( "{}: {}", - response.status(), - text + status, text ))) } } diff --git a/rig-core/src/providers/huggingface/image_generation.rs b/rig-core/src/providers/huggingface/image_generation.rs index 31ea98406..458338a08 100644 --- a/rig-core/src/providers/huggingface/image_generation.rs +++ b/rig-core/src/providers/huggingface/image_generation.rs @@ -77,13 +77,13 @@ where let response = self.client.send(req).await?; if 
!response.status().is_success() { + let status = response.status(); let text: Vec = response.into_body().await?; - let text = String::from_utf8_lossy(&text).into(); + let text: String = String::from_utf8_lossy(&text).into(); return Err(ImageGenerationError::ProviderError(format!( "{}: {}", - response.status(), - text + status, text ))); } diff --git a/rig-core/src/providers/mistral/client.rs b/rig-core/src/providers/mistral/client.rs index 60a9a97b3..d88e47523 100644 --- a/rig-core/src/providers/mistral/client.rs +++ b/rig-core/src/providers/mistral/client.rs @@ -1,65 +1,72 @@ +use bytes::Bytes; use serde::{Deserialize, Serialize}; use super::{ CompletionModel, embedding::{EmbeddingModel, MISTRAL_EMBED}, }; -use crate::client::{ - ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, - VerifyError, +use crate::{ + client::{CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError}, + http_client::HttpClientExt, }; -use crate::impl_conversion_traits; +use crate::{http_client, impl_conversion_traits}; +use std::fmt::Debug; const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { api_key: &'a str, base_url: &'a str, - http_client: Option, + http_client: T, } -impl<'a> ClientBuilder<'a> { +impl<'a, T> ClientBuilder<'a, T> +where + T: Default, +{ pub fn new(api_key: &'a str) -> Self { Self { api_key, base_url: MISTRAL_API_BASE_URL, - http_client: None, + http_client: Default::default(), } } +} - pub fn base_url(mut self, base_url: &'a str) -> Self { - self.base_url = base_url; - self +impl<'a, T> ClientBuilder<'a, T> { + pub fn with_client(self, http_client: U) -> ClientBuilder<'a, U> { + ClientBuilder { + api_key: self.api_key, + base_url: self.base_url, + http_client, + } } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); + pub fn base_url(mut self, base_url: &'a str) -> Self { + self.base_url 
= base_url; self } - pub fn build(self) -> Result { - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? - }; - - Ok(Client { + pub fn build(self) -> Client { + Client { base_url: self.base_url.to_string(), api_key: self.api_key.to_string(), - http_client, - }) + http_client: self.http_client, + } } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, - http_client: reqwest::Client, + http_client: T, } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client +where + T: Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -69,7 +76,23 @@ impl std::fmt::Debug for Client { } } -impl Client { +impl Client +where + T: Default, +{ + /// Create a new Mistral client. For more control, use the `builder` method. + /// + /// # Panics + /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). + pub fn new(api_key: &str) -> Client { + Self::builder(api_key).build() + } +} + +impl Client +where + T: Default, +{ /// Create a new Mistral client builder. /// /// # Example @@ -80,32 +103,49 @@ impl Client { /// let mistral = Client::builder("your-mistral-api-key") /// .build() /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_> { + pub fn builder(api_key: &str) -> ClientBuilder<'_, T> { ClientBuilder::new(api_key) } +} - /// Create a new Mistral client. For more control, use the `builder` method. - /// - /// # Panics - /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 
- pub fn new(api_key: &str) -> Self { - Self::builder(api_key) - .build() - .expect("Mistral client should build") +impl Client +where + T: HttpClientExt, +{ + pub(crate) fn post(&self, path: &str) -> http_client::Result { + let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(|e| http_client::Error::Protocol(e.into()))?; + + Ok(http_client::Request::post(url).header("Authorization", auth_header)) } - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { + pub(crate) fn get(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.post(url).bearer_auth(&self.api_key) + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(|e| http_client::Error::Protocol(e.into()))?; + + Ok(http_client::Request::get(url).header("Authorization", auth_header)) } - pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.get(url).bearer_auth(&self.api_key) + pub(crate) async fn send( + &self, + req: http_client::Request, + ) -> http_client::Result>> + where + Body: Into, + R: From + Send, + { + self.http_client.request(req).await } } -impl ProviderClient for Client { +impl ProviderClient for Client +where + T: HttpClientExt + Debug + Default + Clone + 'static, +{ /// Create a new Mistral client from the `MISTRAL_API_KEY` environment variable. /// Panics if the environment variable is not set. fn from_env() -> Self @@ -124,8 +164,11 @@ impl ProviderClient for Client { } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; +impl CompletionClient for Client +where + T: HttpClientExt + Debug + Default + Clone + 'static, +{ + type CompletionModel = CompletionModel; /// Create a completion model with the given name. 
/// @@ -143,8 +186,11 @@ impl CompletionClient for Client { } } -impl EmbeddingsClient for Client { - type EmbeddingModel = EmbeddingModel; +impl EmbeddingsClient for Client +where + T: HttpClientExt + Debug + Default + Clone + 'static, +{ + type EmbeddingModel = EmbeddingModel; /// Create an embedding model with the given name. /// Note: default embedding dimension of 0 will be used if model is not known. @@ -158,7 +204,7 @@ impl EmbeddingsClient for Client { /// /// let embedding_model = mistral.embedding_model(mistral::MISTRAL_EMBED); /// ``` - fn embedding_model(&self, model: &str) -> EmbeddingModel { + fn embedding_model(&self, model: &str) -> EmbeddingModel { let ndims = match model { MISTRAL_EMBED => 1024, _ => 0, @@ -171,25 +217,37 @@ impl EmbeddingsClient for Client { } } -impl VerifyClient for Client { +impl VerifyClient for Client +where + T: HttpClientExt + Debug + Default + Clone + 'static, +{ #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let response = self.get("/models").send().await?; + let req = self + .get("/models")? 
+ .body(http_client::NoBody) + .map_err(|e| VerifyError::HttpError(e.into()))?; + + let response = self.http_client.request(req).await?; + match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication), reqwest::StatusCode::INTERNAL_SERVER_ERROR => { - Err(VerifyError::ProviderError(response.text().await?)) + let text: Vec = response.into_body().await?; + let text = String::from_utf8_lossy(&text).into(); + Err(VerifyError::ProviderError(text)) } _ => { - response.error_for_status()?; + // TODO: implement equivalent with `http` crate `StatusCode` type + //response.error_for_status()?; Ok(()) } } } } -impl_conversion_traits!(AsTranscription, AsAudioGeneration, AsImageGeneration for Client); +impl_conversion_traits!(AsTranscription, AsAudioGeneration, AsImageGeneration for Client); #[derive(Clone, Debug, Deserialize, Serialize)] pub struct Usage { diff --git a/rig-core/src/providers/mistral/completion.rs b/rig-core/src/providers/mistral/completion.rs index 2a4ea77d1..3eefacb5a 100644 --- a/rig-core/src/providers/mistral/completion.rs +++ b/rig-core/src/providers/mistral/completion.rs @@ -5,6 +5,7 @@ use std::{convert::Infallible, str::FromStr}; use super::client::{Client, Usage}; use crate::completion::GetTokenUsage; +use crate::http_client::{self, HttpClientExt}; use crate::streaming::{RawStreamingChoice, StreamingCompletionResponse}; use crate::{ OneOrMany, @@ -250,13 +251,13 @@ impl FromStr for AssistantContent { } #[derive(Clone)] -pub struct CompletionModel { - pub(crate) client: Client, +pub struct CompletionModel { + pub(crate) client: Client, pub model: String, } -impl CompletionModel { - pub fn new(client: Client, model: &str) -> Self { +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -411,7 +412,10 @@ impl TryFrom for completion::CompletionResponse completion::CompletionModel for CompletionModel +where + T: 
HttpClientExt + Send + Clone + std::fmt::Debug + 'static, +{ type Response = CompletionResponse; type StreamingResponse = CompletionResponse; @@ -420,17 +424,20 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { - let request = self.create_completion_request(completion_request)?; + let body = self.create_completion_request(completion_request)?; + let body = serde_json::to_vec(&body)?; - let response = self + let request = self .client - .post("v1/chat/completions") - .json(&request) - .send() - .await?; + .post("v1/chat/completions")? + .body(body) + .map_err(|e| CompletionError::HttpError(e.into()))?; + + let response = self.client.send(request).await?; if response.status().is_success() { - let text = response.text().await?; + let text = http_client::text(response).await?; + match serde_json::from_str::>(&text)? { ApiResponse::Ok(response) => { tracing::debug!(target: "rig", @@ -442,7 +449,8 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } } else { - Err(CompletionError::ProviderError(response.text().await?)) + let text = http_client::text(response).await?; + Err(CompletionError::ProviderError(text)) } } diff --git a/rig-core/src/providers/mistral/embedding.rs b/rig-core/src/providers/mistral/embedding.rs index 7adb09246..fa9557256 100644 --- a/rig-core/src/providers/mistral/embedding.rs +++ b/rig-core/src/providers/mistral/embedding.rs @@ -1,7 +1,10 @@ use serde::Deserialize; use serde_json::json; -use crate::embeddings::{self, EmbeddingError}; +use crate::{ + embeddings::{self, EmbeddingError}, + http_client::{self, HttpClientExt}, +}; use super::client::{ApiResponse, Client, Usage}; @@ -13,14 +16,14 @@ pub const MISTRAL_EMBED: &str = "mistral-embed"; pub const MAX_DOCUMENTS: usize = 1024; #[derive(Clone)] -pub struct EmbeddingModel { - client: Client, +pub struct EmbeddingModel { + client: 
Client, pub model: String, ndims: usize, } -impl EmbeddingModel { - pub fn new(client: Client, model: &str, ndims: usize) -> Self { +impl EmbeddingModel { + pub fn new(client: Client, model: &str, ndims: usize) -> Self { Self { client, model: model.to_string(), @@ -29,7 +32,10 @@ impl EmbeddingModel { } } -impl embeddings::EmbeddingModel for EmbeddingModel { +impl embeddings::EmbeddingModel for EmbeddingModel +where + T: HttpClientExt + Clone, +{ const MAX_DOCUMENTS: usize = MAX_DOCUMENTS; fn ndims(&self) -> usize { self.ndims @@ -42,18 +48,24 @@ impl embeddings::EmbeddingModel for EmbeddingModel { ) -> Result, EmbeddingError> { let documents = documents.into_iter().collect::>(); - let response = self + let body = serde_json::to_vec(&json!({ + "model": self.model, + "input": documents + }))?; + + let req = self .client - .post("v1/embeddings") - .json(&json!({ - "model": self.model, - "input": documents, - })) - .send() - .await?; + .post("v1/embeddings")? + .body(body) + .map_err(|e| EmbeddingError::HttpError(e.into()))?; + + let response = self.client.send(req).await?; if response.status().is_success() { - match response.json::>().await? 
{ + let body: Vec = response.into_body().await?; + let body: ApiResponse = serde_json::from_slice(&body)?; + + match body { ApiResponse::Ok(response) => { tracing::debug!(target: "rig", "Mistral embedding token usage: {}", @@ -79,7 +91,8 @@ impl embeddings::EmbeddingModel for EmbeddingModel { ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), } } else { - Err(EmbeddingError::ProviderError(response.text().await?)) + let text = http_client::text(response).await?; + Err(EmbeddingError::ProviderError(text)) } } } diff --git a/rig-core/src/providers/openai/client.rs b/rig-core/src/providers/openai/client.rs index ae01f7827..2c254f249 100644 --- a/rig-core/src/providers/openai/client.rs +++ b/rig-core/src/providers/openai/client.rs @@ -3,6 +3,7 @@ use super::audio_generation::AudioGenerationModel; use super::embedding::{ EmbeddingModel, TEXT_EMBEDDING_3_LARGE, TEXT_EMBEDDING_3_SMALL, TEXT_EMBEDDING_ADA_002, }; +use std::fmt::Debug; #[cfg(feature = "image")] use super::image_generation::ImageGenerationModel; @@ -10,10 +11,11 @@ use super::transcription::TranscriptionModel; use crate::{ client::{ - ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, - TranscriptionClient, VerifyClient, VerifyError, + CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient, VerifyClient, + VerifyError, }, extractor::ExtractorBuilder, + http_client::{self, HttpClientExt}, providers::openai::CompletionModel, }; @@ -22,6 +24,7 @@ use crate::client::AudioGenerationClient; #[cfg(feature = "image")] use crate::client::ImageGenerationClient; +use bytes::Bytes; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -30,54 +33,58 @@ use serde::{Deserialize, Serialize}; // ================================================================ const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { api_key: &'a str, base_url: &'a str, - http_client: Option, + 
http_client: T, } -impl<'a> ClientBuilder<'a> { +impl<'a, T> ClientBuilder<'a, T> +where + T: Default, +{ pub fn new(api_key: &'a str) -> Self { Self { api_key, base_url: OPENAI_API_BASE_URL, - http_client: None, + http_client: Default::default(), } } +} +impl<'a, T> ClientBuilder<'a, T> { pub fn base_url(mut self, base_url: &'a str) -> Self { self.base_url = base_url; self } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); - self + pub fn with_client(self, http_client: U) -> ClientBuilder<'a, U> { + ClientBuilder { + api_key: self.api_key, + base_url: self.base_url, + http_client, + } } - - pub fn build(self) -> Result { - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? - }; - - Ok(Client { + pub fn build(self) -> Client { + Client { base_url: self.base_url.to_string(), api_key: self.api_key.to_string(), - http_client, - }) + http_client: self.http_client, + } } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, - http_client: reqwest::Client, + http_client: T, } -impl std::fmt::Debug for Client { +impl Debug for Client +where + T: Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -87,7 +94,10 @@ impl std::fmt::Debug for Client { } } -impl Client { +impl Client +where + T: Default, +{ /// Create a new OpenAI client builder. /// /// # Example @@ -98,42 +108,81 @@ impl Client { /// let openai_client = Client::builder("your-open-ai-api-key") /// .build() /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_> { + pub fn builder(api_key: &str) -> ClientBuilder<'_, T> { ClientBuilder::new(api_key) } /// Create a new OpenAI client. For more control, use the `builder` method. /// - /// # Panics - /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 
pub fn new(api_key: &str) -> Self { - Self::builder(api_key) - .build() - .expect("OpenAI client should build") + Self::builder(api_key).build() } +} - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { +impl Client +where + T: HttpClientExt, +{ + pub(crate) fn post(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.post(url).bearer_auth(&self.api_key) + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(|e| http_client::Error::Protocol(e.into()))?; + Ok(http_client::Request::post(url).header("Authorization", auth_header)) } - pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder { + pub(crate) fn get(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.get(url).bearer_auth(&self.api_key) + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(|e| http_client::Error::Protocol(e.into()))?; + + Ok(http_client::Request::get(url).header("Authorization", auth_header)) + } + + pub(crate) async fn send( + &self, + req: http_client::Request, + ) -> http_client::Result>> + where + U: Into, + R: From + Send, + { + self.http_client.request(req).await + } +} + +impl Client { + pub(crate) async fn send_reqwest( + &self, + req: reqwest::Request, + ) -> reqwest::Result { + self.http_client.execute(req).await + } + + pub(crate) fn post_reqwest(&self, path: &str) -> reqwest::RequestBuilder { + let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + + self.http_client.post(url).bearer_auth(&self.api_key) } /// Create an extractor builder with the given completion model. /// Intended for use exclusively with the Chat Completions API. /// Useful for using extractors with Chat Completion compliant APIs. 
- pub fn extractor_completions_api(&self, model: &str) -> ExtractorBuilder + pub fn extractor_completions_api( + &self, + model: &str, + ) -> ExtractorBuilder, U> where - T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, + U: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, { ExtractorBuilder::new(self.completion_model(model).completions_api()) } } -impl ProviderClient for Client { +impl Client {} + +impl ProviderClient for Client { /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable. /// Panics if the environment variable is not set. fn from_env() -> Self { @@ -141,7 +190,7 @@ impl ProviderClient for Client { let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); match base_url { - Some(url) => Self::builder(&api_key).base_url(&url).build().unwrap(), + Some(url) => Self::builder(&api_key).base_url(&url).build(), None => Self::new(&api_key), } } @@ -154,8 +203,8 @@ impl ProviderClient for Client { } } -impl CompletionClient for Client { - type CompletionModel = super::responses_api::ResponsesCompletionModel; +impl CompletionClient for Client { + type CompletionModel = super::responses_api::ResponsesCompletionModel; /// Create a completion model with the given name. 
/// /// # Example @@ -167,13 +216,16 @@ impl CompletionClient for Client { /// /// let gpt4 = openai.completion_model(openai::GPT_4); /// ``` - fn completion_model(&self, model: &str) -> super::responses_api::ResponsesCompletionModel { + fn completion_model( + &self, + model: &str, + ) -> super::responses_api::ResponsesCompletionModel { super::responses_api::ResponsesCompletionModel::new(self.clone(), model) } } -impl EmbeddingsClient for Client { - type EmbeddingModel = EmbeddingModel; +impl EmbeddingsClient for Client { + type EmbeddingModel = EmbeddingModel; fn embedding_model(&self, model: &str) -> Self::EmbeddingModel { let ndims = match model { TEXT_EMBEDDING_3_LARGE => 3072, @@ -188,8 +240,8 @@ impl EmbeddingsClient for Client { } } -impl TranscriptionClient for Client { - type TranscriptionModel = TranscriptionModel; +impl TranscriptionClient for Client { + type TranscriptionModel = TranscriptionModel; /// Create a transcription model with the given name. /// /// # Example @@ -201,14 +253,14 @@ impl TranscriptionClient for Client { /// /// let gpt4 = openai.transcription_model(openai::WHISPER_1); /// ``` - fn transcription_model(&self, model: &str) -> TranscriptionModel { + fn transcription_model(&self, model: &str) -> TranscriptionModel { TranscriptionModel::new(self.clone(), model) } } #[cfg(feature = "image")] -impl ImageGenerationClient for Client { - type ImageGenerationModel = ImageGenerationModel; +impl ImageGenerationClient for Client { + type ImageGenerationModel = ImageGenerationModel; /// Create an image generation model with the given name. /// /// # Example @@ -244,18 +296,25 @@ impl AudioGenerationClient for Client { } } -impl VerifyClient for Client { +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let response = self.get("/models").send().await?; + let req = self + .get("/models")? 
+ .body(http_client::NoBody) + .map_err(|e| VerifyError::HttpError(e.into()))?; + + let response = self.send(req).await?; + match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication), reqwest::StatusCode::INTERNAL_SERVER_ERROR => { - Err(VerifyError::ProviderError(response.text().await?)) + let text = http_client::text(response).await?; + Err(VerifyError::ProviderError(text)) } _ => { - response.error_for_status()?; + //response.error_for_status()?; Ok(()) } } diff --git a/rig-core/src/providers/openai/completion/mod.rs b/rig-core/src/providers/openai/completion/mod.rs index 797f29bbe..570e72593 100644 --- a/rig-core/src/providers/openai/completion/mod.rs +++ b/rig-core/src/providers/openai/completion/mod.rs @@ -4,9 +4,10 @@ use super::{ApiErrorResponse, ApiResponse, Client, streaming::StreamingCompletionResponse}; use crate::completion::{CompletionError, CompletionRequest}; +use crate::http_client::HttpClientExt; use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType}; use crate::one_or_many::string_or_one_or_many; -use crate::{OneOrMany, completion, json_utils, message}; +use crate::{OneOrMany, completion, http_client, json_utils, message}; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use std::convert::Infallible; @@ -711,24 +712,23 @@ impl fmt::Display for Usage { } #[derive(Clone)] -pub struct CompletionModel { - pub(crate) client: Client, +pub struct CompletionModel { + pub(crate) client: Client, /// Name of the model (e.g.: gpt-3.5-turbo-1106) pub model: String, } -impl CompletionModel { - pub fn new(client: Client, model: &str) -> Self { +impl CompletionModel +where + T: HttpClientExt + Default + std::fmt::Debug + Clone + 'static, +{ + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), } } - pub fn into_agent_builder(self) -> crate::agent::AgentBuilder { - crate::agent::AgentBuilder::new(self) - 
} - pub(crate) fn create_completion_request( &self, completion_request: CompletionRequest, @@ -794,7 +794,13 @@ impl CompletionModel { } } -impl completion::CompletionModel for CompletionModel { +impl CompletionModel { + pub fn into_agent_builder(self) -> crate::agent::AgentBuilder { + crate::agent::AgentBuilder::new(self) + } +} + +impl completion::CompletionModel for CompletionModel { type Response = CompletionResponse; type StreamingResponse = StreamingCompletionResponse; @@ -810,18 +816,22 @@ impl completion::CompletionModel for CompletionModel { request = serde_json::to_string_pretty(&request).unwrap() ); - let response = self + let body = serde_json::to_vec(&request)?; + + let req = self .client - .post("/chat/completions") - .json(&request) - .send() - .await?; + .post("/chat/completions")? + .body(body) + .map_err(|e| CompletionError::HttpError(e.into()))?; + + let response = self.client.send(req).await?; if response.status().is_success() { - let t = response.text().await?; - tracing::debug!(target: "rig", "OpenAI completion error: {}", t); + let text = http_client::text(response).await?; + + tracing::debug!(target: "rig", "OpenAI completion error: {}", text); - match serde_json::from_str::>(&t)? { + match serde_json::from_str::>(&text)? 
{ ApiResponse::Ok(response) => { tracing::info!(target: "rig", "OpenAI completion token usage: {:?}", @@ -832,7 +842,8 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } } else { - Err(CompletionError::ProviderError(response.text().await?)) + let text = http_client::text(response).await?; + Err(CompletionError::ProviderError(text)) } } diff --git a/rig-core/src/providers/openai/completion/streaming.rs b/rig-core/src/providers/openai/completion/streaming.rs index fd23d4ade..94313002e 100644 --- a/rig-core/src/providers/openai/completion/streaming.rs +++ b/rig-core/src/providers/openai/completion/streaming.rs @@ -66,7 +66,7 @@ impl GetTokenUsage for StreamingCompletionResponse { } } -impl CompletionModel { +impl CompletionModel { pub(crate) async fn stream( &self, completion_request: CompletionRequest, @@ -78,7 +78,7 @@ impl CompletionModel { json!({"stream": true, "stream_options": {"include_usage": true}}), ); - let builder = self.client.post("/chat/completions").json(&request); + let builder = self.client.post_reqwest("/chat/completions").json(&request); send_compatible_streaming_request(builder).await } } diff --git a/rig-core/src/providers/openai/embedding.rs b/rig-core/src/providers/openai/embedding.rs index 9a1fc7ac4..6458aaf59 100644 --- a/rig-core/src/providers/openai/embedding.rs +++ b/rig-core/src/providers/openai/embedding.rs @@ -1,6 +1,7 @@ use super::{ApiErrorResponse, ApiResponse, Client, completion::Usage}; -use crate::embeddings; use crate::embeddings::EmbeddingError; +use crate::http_client::HttpClientExt; +use crate::{embeddings, http_client}; use serde::Deserialize; use serde_json::json; @@ -45,13 +46,16 @@ pub struct EmbeddingData { } #[derive(Clone)] -pub struct EmbeddingModel { - client: Client, +pub struct EmbeddingModel { + client: Client, pub model: String, ndims: usize, } -impl embeddings::EmbeddingModel for EmbeddingModel { +impl embeddings::EmbeddingModel 
for EmbeddingModel +where + T: HttpClientExt + Clone + std::fmt::Debug + Send + 'static, +{ const MAX_DOCUMENTS: usize = 1024; fn ndims(&self) -> usize { @@ -65,18 +69,27 @@ impl embeddings::EmbeddingModel for EmbeddingModel { ) -> Result, EmbeddingError> { let documents = documents.into_iter().collect::>(); - let response = self + let body = json!({ + "model": self.model, + "input": documents, + }); + + let body = serde_json::to_vec(&body)?; + + let req = self .client .post("/embeddings") - .json(&json!({ - "model": self.model, - "input": documents, - })) - .send() - .await?; + .map_err(|e| EmbeddingError::HttpError(e.into()))? + .body(body) + .map_err(|e| EmbeddingError::HttpError(e.into()))?; + + let response = self.client.send(req).await?; if response.status().is_success() { - match response.json::>().await? { + let body: Vec = response.into_body().await?; + let body: ApiResponse = serde_json::from_slice(&body)?; + + match body { ApiResponse::Ok(response) => { tracing::info!(target: "rig", "OpenAI embedding token usage: {:?}", @@ -102,13 +115,14 @@ impl embeddings::EmbeddingModel for EmbeddingModel { ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), } } else { - Err(EmbeddingError::ProviderError(response.text().await?)) + let text = http_client::text(response).await?; + Err(EmbeddingError::ProviderError(text)) } } } -impl EmbeddingModel { - pub fn new(client: Client, model: &str, ndims: usize) -> Self { +impl EmbeddingModel { + pub fn new(client: Client, model: &str, ndims: usize) -> Self { Self { client, model: model.to_string(), diff --git a/rig-core/src/providers/openai/image_generation.rs b/rig-core/src/providers/openai/image_generation.rs index 69ae69d27..aa2b2c9d8 100644 --- a/rig-core/src/providers/openai/image_generation.rs +++ b/rig-core/src/providers/openai/image_generation.rs @@ -1,7 +1,8 @@ -use crate::image_generation; +use crate::http_client::HttpClientExt; use crate::image_generation::{ImageGenerationError, 
ImageGenerationRequest}; use crate::json_utils::merge_inplace; use crate::providers::openai::{ApiResponse, Client}; +use crate::{http_client, image_generation}; use base64::Engine; use base64::prelude::BASE64_STANDARD; use serde::Deserialize; @@ -46,14 +47,14 @@ impl TryFrom } #[derive(Clone)] -pub struct ImageGenerationModel { - client: Client, +pub struct ImageGenerationModel { + client: Client, /// Name of the model (e.g.: dall-e-2) pub model: String, } -impl ImageGenerationModel { - pub(crate) fn new(client: Client, model: &str) -> Self { +impl ImageGenerationModel { + pub(crate) fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -61,7 +62,10 @@ impl ImageGenerationModel { } } -impl image_generation::ImageGenerationModel for ImageGenerationModel { +impl image_generation::ImageGenerationModel for ImageGenerationModel +where + T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static, +{ type Response = ImageGenerationResponse; #[cfg_attr(feature = "worker", worker::send)] @@ -85,24 +89,29 @@ impl image_generation::ImageGenerationModel for ImageGenerationModel { ); } - let response = self + let body = serde_json::to_vec(&request)?; + + let request = self .client - .post("/images/generations") - .json(&request) - .send() - .await?; + .post("/images/generations")? + .body(body) + .map_err(|e| ImageGenerationError::HttpError(e.into()))?; + + let response = self.client.send(request).await?; if !response.status().is_success() { + let status = response.status(); + let text = http_client::text(response).await?; + return Err(ImageGenerationError::ProviderError(format!( "{}: {}", - response.status(), - response.text().await? + status, text, ))); } - let t = response.text().await?; + let text = http_client::text(response).await?; - match serde_json::from_str::>(&t)? { + match serde_json::from_str::>(&text)? 
{ ApiResponse::Ok(response) => response.try_into(), ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)), } diff --git a/rig-core/src/providers/openai/responses_api/mod.rs b/rig-core/src/providers/openai/responses_api/mod.rs index d358c6885..ee0b7a500 100644 --- a/rig-core/src/providers/openai/responses_api/mod.rs +++ b/rig-core/src/providers/openai/responses_api/mod.rs @@ -10,9 +10,10 @@ use super::{Client, responses_api::streaming::StreamingCompletionResponse}; use super::{ImageUrl, InputAudio, SystemContent}; use crate::completion::CompletionError; -use crate::json_utils; +use crate::http_client::HttpClientExt; use crate::message::{AudioMediaType, Document, DocumentSourceKind, MessageError, MimeType, Text}; use crate::one_or_many::string_or_one_or_many; +use crate::{http_client, json_utils}; use crate::{OneOrMany, completion, message}; use serde::{Deserialize, Serialize}; @@ -629,16 +630,19 @@ impl TryFrom<(String, crate::completion::CompletionRequest)> for CompletionReque /// The completion model struct for OpenAI's response API. #[derive(Clone)] -pub struct ResponsesCompletionModel { +pub struct ResponsesCompletionModel { /// The OpenAI client - pub(crate) client: Client, + pub(crate) client: Client, /// Name of the model (e.g.: gpt-3.5-turbo-1106) pub model: String, } -impl ResponsesCompletionModel { +impl ResponsesCompletionModel +where + T: HttpClientExt + Clone + Default + std::fmt::Debug + 'static, +{ /// Creates a new [`ResponsesCompletionModel`]. - pub fn new(client: Client, model: &str) -> Self { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -646,7 +650,7 @@ impl ResponsesCompletionModel { } /// Use the Completions API instead of Responses. 
- pub fn completions_api(self) -> crate::providers::openai::completion::CompletionModel { + pub fn completions_api(self) -> crate::providers::openai::completion::CompletionModel { crate::providers::openai::completion::CompletionModel::new(self.client, &self.model) } @@ -967,7 +971,7 @@ pub enum OutputRole { Assistant, } -impl completion::CompletionModel for ResponsesCompletionModel { +impl completion::CompletionModel for ResponsesCompletionModel { type Response = CompletionResponse; type StreamingResponse = StreamingCompletionResponse; @@ -976,21 +980,28 @@ impl completion::CompletionModel for ResponsesCompletionModel { &self, completion_request: crate::completion::CompletionRequest, ) -> Result, CompletionError> { - let request = self.create_completion_request(completion_request)?; - let request = serde_json::to_value(request)?; + let body = self.create_completion_request(completion_request)?; + tracing::debug!("OpenAI input: {}", serde_json::to_string_pretty(&body)?); - tracing::debug!("OpenAI input: {}", serde_json::to_string_pretty(&request)?); + let body = serde_json::to_vec(&body)?; - let response = self.client.post("/responses").json(&request).send().await?; + let req = self + .client + .post("/responses")? 
+ .body(body) + .map_err(|e| CompletionError::HttpError(e.into()))?; + + let response = self.client.send(req).await?; if response.status().is_success() { - let t = response.text().await?; - tracing::debug!(target: "rig", "OpenAI response: {}", t); + let text = http_client::text(response).await?; + tracing::debug!(target: "rig", "OpenAI response: {}", text); - let response = serde_json::from_str::(&t)?; + let response = serde_json::from_str::(&text)?; response.try_into() } else { - Err(CompletionError::ProviderError(response.text().await?)) + let text = http_client::text(response).await?; + Err(CompletionError::ProviderError(text)) } } @@ -1002,7 +1013,7 @@ impl completion::CompletionModel for ResponsesCompletionModel { crate::streaming::StreamingCompletionResponse, CompletionError, > { - Self::stream(self, request).await + ResponsesCompletionModel::stream(self, request).await } } diff --git a/rig-core/src/providers/openai/responses_api/streaming.rs b/rig-core/src/providers/openai/responses_api/streaming.rs index dc2c836f3..d914f43af 100644 --- a/rig-core/src/providers/openai/responses_api/streaming.rs +++ b/rig-core/src/providers/openai/responses_api/streaming.rs @@ -191,7 +191,7 @@ pub enum SummaryPartChunkPart { SummaryText { text: String }, } -impl ResponsesCompletionModel { +impl ResponsesCompletionModel { pub(crate) async fn stream( &self, completion_request: crate::completion::CompletionRequest, @@ -202,7 +202,7 @@ impl ResponsesCompletionModel { tracing::debug!("Input: {}", serde_json::to_string_pretty(&request)?); - let builder = self.client.post("/responses").json(&request); + let builder = self.client.post_reqwest("/responses").json(&request); send_compatible_streaming_request(builder).await } } diff --git a/rig-core/src/providers/openai/transcription.rs b/rig-core/src/providers/openai/transcription.rs index 9b001f6a2..efa8615ad 100644 --- a/rig-core/src/providers/openai/transcription.rs +++ b/rig-core/src/providers/openai/transcription.rs @@ -1,3 +1,4 @@ 
+use crate::http_client::{self, HttpClientExt}; use crate::providers::openai::{ApiResponse, Client}; use crate::transcription; use crate::transcription::TranscriptionError; @@ -28,14 +29,14 @@ impl TryFrom } #[derive(Clone)] -pub struct TranscriptionModel { - client: Client, +pub struct TranscriptionModel { + client: Client, /// Name of the model (e.g.: gpt-3.5-turbo-1106) pub model: String, } -impl TranscriptionModel { - pub fn new(client: Client, model: &str) -> Self { +impl TranscriptionModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -43,7 +44,7 @@ impl TranscriptionModel { } } -impl transcription::TranscriptionModel for TranscriptionModel { +impl transcription::TranscriptionModel for TranscriptionModel { type Response = TranscriptionResponse; #[cfg_attr(feature = "worker", worker::send)] @@ -83,23 +84,30 @@ impl transcription::TranscriptionModel for TranscriptionModel { let response = self .client - .post("audio/transcriptions") + .post_reqwest("audio/transcriptions") .multipart(body) .send() - .await?; + .await + .map_err(|e| TranscriptionError::HttpError(http_client::Error::Instance(e.into())))?; if response.status().is_success() { match response .json::>() - .await? - { + .await + .map_err(|e| { + TranscriptionError::HttpError(http_client::Error::Instance(e.into())) + })? 
{ ApiResponse::Ok(response) => response.try_into(), ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError( api_error_response.message, )), } } else { - Err(TranscriptionError::ProviderError(response.text().await?)) + Err(TranscriptionError::ProviderError( + response.text().await.map_err(|e| { + TranscriptionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } } } From eaefd47b6962c78f493735ca7b6fc65869fade93 Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Thu, 2 Oct 2025 16:39:37 -0400 Subject: [PATCH 04/20] Four clients to go --- rig-core/src/audio_generation.rs | 4 +- rig-core/src/http_client.rs | 2 +- rig-core/src/providers/azure.rs | 284 +++++++++++------- rig-core/src/providers/deepseek.rs | 157 +++++++--- rig-core/src/providers/galadriel.rs | 169 +++++++---- rig-core/src/providers/groq.rs | 183 +++++++---- rig-core/src/providers/hyperbolic.rs | 231 +++++++++----- rig-core/src/providers/mira.rs | 165 ++++++---- rig-core/src/providers/openrouter/client.rs | 117 +++++--- .../src/providers/openrouter/completion.rs | 28 +- .../src/providers/openrouter/streaming.rs | 18 +- rig-core/src/providers/together/client.rs | 160 +++++++--- rig-core/src/providers/together/completion.rs | 33 +- rig-core/src/providers/together/embedding.rs | 48 ++- rig-core/src/providers/together/streaming.rs | 7 +- rig-core/src/providers/xai/client.rs | 140 ++++++--- rig-core/src/providers/xai/completion.rs | 31 +- rig-core/src/providers/xai/streaming.rs | 7 +- 18 files changed, 1177 insertions(+), 607 deletions(-) diff --git a/rig-core/src/audio_generation.rs b/rig-core/src/audio_generation.rs index 5d14d034a..4e459b3e9 100644 --- a/rig-core/src/audio_generation.rs +++ b/rig-core/src/audio_generation.rs @@ -1,6 +1,6 @@ //! Everything related to audio generation (ie, Text To Speech). //! Rig abstracts over a number of different providers using the [AudioGenerationModel] trait. 
-use crate::client::audio_generation::AudioGenerationModelHandle; +use crate::{client::audio_generation::AudioGenerationModelHandle, http_client}; use futures::future::BoxFuture; use serde_json::Value; use std::sync::Arc; @@ -10,7 +10,7 @@ use thiserror::Error; pub enum AudioGenerationError { /// Http error (e.g.: connection error, timeout, etc.) #[error("HttpError: {0}")] - HttpError(#[from] reqwest::Error), + HttpError(#[from] http_client::Error), /// Json error (e.g.: serialization, deserialization) #[error("JsonError: {0}")] diff --git a/rig-core/src/http_client.rs b/rig-core/src/http_client.rs index 195ad1d71..479a3470b 100644 --- a/rig-core/src/http_client.rs +++ b/rig-core/src/http_client.rs @@ -1,6 +1,6 @@ use bytes::Bytes; use futures::stream::{BoxStream, StreamExt}; -pub use http::{HeaderValue, Method, Request, Response, Uri, request::Builder}; +pub use http::{HeaderMap, HeaderValue, Method, Request, Response, Uri, request::Builder}; use reqwest::Body; use std::future::Future; use std::pin::Pin; diff --git a/rig-core/src/providers/azure.rs b/rig-core/src/providers/azure.rs index 9488ab62b..e3af36f1b 100644 --- a/rig-core/src/providers/azure.rs +++ b/rig-core/src/providers/azure.rs @@ -11,6 +11,7 @@ use super::openai::{TranscriptionResponse, send_compatible_streaming_request}; +use crate::http_client::{self, HttpClientExt}; use crate::json_utils::merge; use crate::streaming::StreamingCompletionResponse; use crate::{ @@ -21,6 +22,7 @@ use crate::{ providers::openai, transcription::{self, TranscriptionError}, }; +use bytes::Bytes; use reqwest::header::AUTHORIZATION; use reqwest::multipart::Part; use serde::Deserialize; @@ -31,23 +33,28 @@ use serde_json::json; const DEFAULT_API_VERSION: &str = "2024-10-21"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { auth: AzureOpenAIAuth, api_version: Option<&'a str>, azure_endpoint: &'a str, - http_client: Option, + http_client: T, } -impl<'a> ClientBuilder<'a> { +impl<'a, T> ClientBuilder<'a, T> 
+where + T: Default, +{ pub fn new(auth: impl Into, endpoint: &'a str) -> Self { Self { auth: auth.into(), api_version: None, azure_endpoint: endpoint, - http_client: None, + http_client: Default::default(), } } +} +impl<'a, T> ClientBuilder<'a, T> { /// API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview) pub fn api_version(mut self, api_version: &'a str) -> Self { self.api_version = Some(api_version); @@ -60,38 +67,39 @@ impl<'a> ClientBuilder<'a> { self } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); - self + pub fn with_client(self, http_client: U) -> ClientBuilder<'a, U> { + ClientBuilder { + auth: self.auth, + api_version: self.api_version, + azure_endpoint: self.azure_endpoint, + http_client, + } } - pub fn build(self) -> Result { - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? - }; - + pub fn build(self) -> Client { let api_version = self.api_version.unwrap_or(DEFAULT_API_VERSION); - Ok(Client { + Client { api_version: api_version.to_string(), azure_endpoint: self.azure_endpoint.to_string(), auth: self.auth, - http_client, - }) + http_client: self.http_client, + } } } #[derive(Clone)] -pub struct Client { +pub struct Client { api_version: String, azure_endpoint: String, auth: AzureOpenAIAuth, - http_client: reqwest::Client, + http_client: T, } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client +where + T: std::fmt::Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("azure_endpoint", &self.azure_endpoint) @@ -140,7 +148,10 @@ impl AzureOpenAIAuth { } } -impl Client { +impl Client +where + T: Default, +{ /// Create a new Azure OpenAI client builder. 
/// /// # Example @@ -151,29 +162,64 @@ impl Client { /// let azure = Client::builder("your-azure-api-key", "https://{your-resource-name}.openai.azure.com") /// .build() /// ``` - pub fn builder(auth: impl Into, endpoint: &str) -> ClientBuilder<'_> { + pub fn builder(auth: impl Into, endpoint: &str) -> ClientBuilder<'_, T> { ClientBuilder::new(auth, endpoint) } /// Creates a new Azure OpenAI client. For more control, use the `builder` method. - /// - /// # Panics - /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). pub fn new(auth: impl Into, endpoint: &str) -> Self { - Self::builder(auth, endpoint) - .build() - .expect("Azure OpenAI client should build") + Self::builder(auth, endpoint).build() + } +} + +impl Client +where + T: HttpClientExt, +{ + fn post(&self, url: String) -> http_client::Builder { + let (key, value) = self.auth.as_header(); + + http_client::Request::post(url).header(key, value) } - fn post_embedding(&self, deployment_id: &str) -> reqwest::RequestBuilder { + fn post_embedding(&self, deployment_id: &str) -> http_client::Builder { let url = format!( "{}/openai/deployments/{}/embeddings?api-version={}", self.azure_endpoint, deployment_id, self.api_version ) .replace("//", "/"); - let (key, value) = self.auth.as_header(); - self.http_client.post(url).header(key, value) + self.post(url) + } + + #[cfg(feature = "audio")] + fn post_audio_generation(&self, deployment_id: &str) -> http_client::Builder { + let url = format!( + "{}/openai/deployments/{}/audio/speech?api-version={}", + self.azure_endpoint, deployment_id, self.api_version + ) + .replace("//", "/"); + + self.post(url) + } + + async fn send( + &self, + req: http_client::Request, + ) -> http_client::Result>> + where + U: Into, + R: From + Send, + { + self.http_client.request(req).await + } +} + +impl Client { + fn post_reqwest(&self, url: String) -> reqwest::RequestBuilder { + let (key, val) = self.auth.as_header(); + + self.http_client.post(url).header(key, 
val) } fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder { @@ -182,8 +228,8 @@ impl Client { self.azure_endpoint, deployment_id, self.api_version ) .replace("//", "/"); - let (key, value) = self.auth.as_header(); - self.http_client.post(url).header(key, value) + + self.post_reqwest(url) } fn post_transcription(&self, deployment_id: &str) -> reqwest::RequestBuilder { @@ -192,8 +238,8 @@ impl Client { self.azure_endpoint, deployment_id, self.api_version ) .replace("//", "/"); - let (key, value) = self.auth.as_header(); - self.http_client.post(url).header(key, value) + + self.post_reqwest(url) } #[cfg(feature = "image")] @@ -203,23 +249,12 @@ impl Client { self.azure_endpoint, deployment_id, self.api_version ) .replace("//", "/"); - let (key, value) = self.auth.as_header(); - self.http_client.post(url).header(key, value) - } - #[cfg(feature = "audio")] - fn post_audio_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder { - let url = format!( - "{}/openai/deployments/{}/audio/speech?api-version={}", - self.azure_endpoint, deployment_id, self.api_version - ) - .replace("//", "/"); - let (key, value) = self.auth.as_header(); - self.http_client.post(url).header(key, value) + self.post_reqwest(url) } } -impl ProviderClient for Client { +impl ProviderClient for Client { /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables. 
fn from_env() -> Self { let auth = if let Ok(api_key) = std::env::var("AZURE_API_KEY") { @@ -236,7 +271,6 @@ impl ProviderClient for Client { Self::builder(auth, &azure_endpoint) .api_version(&api_version) .build() - .expect("Azure OpenAI client should build") } fn from_val(input: crate::client::ProviderValue) -> Self { @@ -246,15 +280,12 @@ impl ProviderClient for Client { panic!("Incorrect provider value type") }; let auth = AzureOpenAIAuth::ApiKey(api_key.to_string()); - Self::builder(auth, &header) - .api_version(&version) - .build() - .expect("Azure OpenAI client should build") + Self::builder(auth, &header).api_version(&version).build() } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; /// Create a completion model with the given name. /// @@ -267,13 +298,13 @@ impl CompletionClient for Client { /// /// let gpt4 = azure.completion_model(azure::GPT_4); /// ``` - fn completion_model(&self, model: &str) -> CompletionModel { + fn completion_model(&self, model: &str) -> CompletionModel { CompletionModel::new(self.clone(), model) } } -impl EmbeddingsClient for Client { - type EmbeddingModel = EmbeddingModel; +impl EmbeddingsClient for Client { + type EmbeddingModel = EmbeddingModel; /// Create an embedding model with the given name. /// Note: default embedding dimension of 0 will be used if model is not known. 
@@ -288,7 +319,7 @@ impl EmbeddingsClient for Client { /// /// let embedding_model = azure.embedding_model(azure::TEXT_EMBEDDING_3_LARGE); /// ``` - fn embedding_model(&self, model: &str) -> EmbeddingModel { + fn embedding_model(&self, model: &str) -> EmbeddingModel { let ndims = match model { TEXT_EMBEDDING_3_LARGE => 3072, TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536, @@ -308,13 +339,17 @@ impl EmbeddingsClient for Client { /// /// let embedding_model = azure.embedding_model("model-unknown-to-rig", 3072); /// ``` - fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel { + fn embedding_model_with_ndims( + &self, + model: &str, + ndims: usize, + ) -> EmbeddingModel { EmbeddingModel::new(self.clone(), model, ndims) } } -impl TranscriptionClient for Client { - type TranscriptionModel = TranscriptionModel; +impl TranscriptionClient for Client { + type TranscriptionModel = TranscriptionModel; /// Create a transcription model with the given name. /// @@ -327,7 +362,7 @@ impl TranscriptionClient for Client { /// /// let whisper = azure.transcription_model("model-unknown-to-rig"); /// ``` - fn transcription_model(&self, model: &str) -> TranscriptionModel { + fn transcription_model(&self, model: &str) -> TranscriptionModel { TranscriptionModel::new(self.clone(), model) } } @@ -401,13 +436,16 @@ impl std::fmt::Display for Usage { } #[derive(Clone)] -pub struct EmbeddingModel { - client: Client, +pub struct EmbeddingModel { + client: Client, pub model: String, ndims: usize, } -impl embeddings::EmbeddingModel for EmbeddingModel { +impl embeddings::EmbeddingModel for EmbeddingModel +where + T: HttpClientExt + Default + Clone, +{ const MAX_DOCUMENTS: usize = 1024; fn ndims(&self) -> usize { @@ -421,17 +459,23 @@ impl embeddings::EmbeddingModel for EmbeddingModel { ) -> Result, EmbeddingError> { let documents = documents.into_iter().collect::>(); - let response = self + let body = serde_json::to_vec(&json!({ + "input": documents, + }))?; + 
+ let req = self .client .post_embedding(&self.model) - .json(&json!({ - "input": documents, - })) - .send() - .await?; + .body(body) + .map_err(|e| EmbeddingError::HttpError(e.into()))?; + + let response = self.client.send(req).await?; if response.status().is_success() { - match response.json::>().await? { + let body: Vec = response.into_body().await?; + let body: ApiResponse = serde_json::from_slice(&body)?; + + match body { ApiResponse::Ok(response) => { tracing::info!(target: "rig", "Azure embedding token usage: {}", @@ -457,13 +501,14 @@ impl embeddings::EmbeddingModel for EmbeddingModel { ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), } } else { - Err(EmbeddingError::ProviderError(response.text().await?)) + let text = http_client::text(response).await?; + Err(EmbeddingError::ProviderError(text)) } } } -impl EmbeddingModel { - pub fn new(client: Client, model: &str, ndims: usize) -> Self { +impl EmbeddingModel { + pub fn new(client: Client, model: &str, ndims: usize) -> Self { Self { client, model: model.to_string(), @@ -503,14 +548,14 @@ pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct"; pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k"; #[derive(Clone)] -pub struct CompletionModel { - client: Client, +pub struct CompletionModel { + client: Client, /// Name of the model (e.g.: gpt-4o-mini) pub model: String, } -impl CompletionModel { - pub fn new(client: Client, model: &str) -> Self { +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -566,7 +611,7 @@ impl CompletionModel { } } -impl completion::CompletionModel for CompletionModel { +impl completion::CompletionModel for CompletionModel { type Response = openai::CompletionResponse; type StreamingResponse = openai::StreamingCompletionResponse; @@ -582,10 +627,15 @@ impl completion::CompletionModel for CompletionModel { .post_chat_completion(&self.model) .json(&request) .send() - .await?; + .await 
+ .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; if response.status().is_success() { - let t = response.text().await?; + let t = response + .text() + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; + tracing::debug!(target: "rig", "Azure completion error: {}", t); match serde_json::from_str::>(&t)? { @@ -599,7 +649,11 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } } else { - Err(CompletionError::ProviderError(response.text().await?)) + Err(CompletionError::ProviderError( + response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } } @@ -629,14 +683,14 @@ impl completion::CompletionModel for CompletionModel { // ================================================================ #[derive(Clone)] -pub struct TranscriptionModel { - client: Client, +pub struct TranscriptionModel { + client: Client, /// Name of the model (e.g.: gpt-3.5-turbo-1106) pub model: String, } -impl TranscriptionModel { - pub fn new(client: Client, model: &str) -> Self { +impl TranscriptionModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -644,7 +698,7 @@ impl TranscriptionModel { } } -impl transcription::TranscriptionModel for TranscriptionModel { +impl transcription::TranscriptionModel for TranscriptionModel { type Response = TranscriptionResponse; #[cfg_attr(feature = "worker", worker::send)] @@ -684,20 +738,27 @@ impl transcription::TranscriptionModel for TranscriptionModel { .post_transcription(&self.model) .multipart(body) .send() - .await?; + .await + .map_err(|e| TranscriptionError::HttpError(http_client::Error::Instance(e.into())))?; if response.status().is_success() { match response .json::>() - .await? - { + .await + .map_err(|e| { + TranscriptionError::HttpError(http_client::Error::Instance(e.into())) + })? 
{ ApiResponse::Ok(response) => response.try_into(), ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError( api_error_response.message, )), } } else { - Err(TranscriptionError::ProviderError(response.text().await?)) + Err(TranscriptionError::ProviderError( + response.text().await.map_err(|e| { + TranscriptionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } } } @@ -711,18 +772,18 @@ pub use image_generation::*; #[cfg_attr(docsrs, doc(cfg(feature = "image")))] mod image_generation { use crate::client::ImageGenerationClient; - use crate::image_generation; use crate::image_generation::{ImageGenerationError, ImageGenerationRequest}; use crate::providers::azure::{ApiResponse, Client}; use crate::providers::openai::ImageGenerationResponse; + use crate::{http_client, image_generation}; use serde_json::json; #[derive(Clone)] - pub struct ImageGenerationModel { - client: Client, + pub struct ImageGenerationModel { + client: Client, pub model: String, } - impl image_generation::ImageGenerationModel for ImageGenerationModel { + impl image_generation::ImageGenerationModel for ImageGenerationModel { type Response = ImageGenerationResponse; #[cfg_attr(feature = "worker", worker::send)] @@ -743,17 +804,24 @@ mod image_generation { .post_image_generation(&self.model) .json(&request) .send() - .await?; + .await + .map_err(|e| { + ImageGenerationError::HttpError(http_client::Error::Instance(e.into())) + })?; if !response.status().is_success() { return Err(ImageGenerationError::ProviderError(format!( "{}: {}", response.status(), - response.text().await? + response.text().await.map_err(|e| { + ImageGenerationError::HttpError(http_client::Error::Instance(e.into())) + })? ))); } - let t = response.text().await?; + let t = response.text().await.map_err(|e| { + ImageGenerationError::HttpError(http_client::Error::Instance(e.into())) + })?; match serde_json::from_str::>(&t)? 
{ ApiResponse::Ok(response) => response.try_into(), @@ -762,8 +830,8 @@ mod image_generation { } } - impl ImageGenerationClient for Client { - type ImageGenerationModel = ImageGenerationModel; + impl ImageGenerationClient for Client { + type ImageGenerationModel = ImageGenerationModel; fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel { ImageGenerationModel { @@ -797,12 +865,12 @@ mod audio_generation { use serde_json::json; #[derive(Clone)] - pub struct AudioGenerationModel { - client: Client, + pub struct AudioGenerationModel { + client: Client, model: String, } - impl audio_generation::AudioGenerationModel for AudioGenerationModel { + impl audio_generation::AudioGenerationModel for AudioGenerationModel { type Response = Bytes; #[cfg_attr(feature = "worker", worker::send)] @@ -841,8 +909,8 @@ mod audio_generation { } } - impl AudioGenerationClient for Client { - type AudioGenerationModel = AudioGenerationModel; + impl AudioGenerationClient for Client { + type AudioGenerationModel = AudioGenerationModel; fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel { AudioGenerationModel { @@ -853,7 +921,7 @@ mod audio_generation { } } -impl VerifyClient for Client { +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { // There is currently no way to verify the Azure OpenAI API key or token without diff --git a/rig-core/src/providers/deepseek.rs b/rig-core/src/providers/deepseek.rs index 63501083e..7166f94d3 100644 --- a/rig-core/src/providers/deepseek.rs +++ b/rig-core/src/providers/deepseek.rs @@ -10,6 +10,7 @@ //! 
``` use async_stream::stream; +use bytes::Bytes; use futures::StreamExt; use reqwest_eventsource::{Event, RequestBuilderExt}; use std::collections::HashMap; @@ -18,6 +19,7 @@ use crate::client::{ ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError, }; use crate::completion::GetTokenUsage; +use crate::http_client::{self, HttpClientExt}; use crate::json_utils::merge; use crate::message::{Document, DocumentSourceKind}; use crate::{ @@ -36,54 +38,59 @@ use super::openai::StreamingToolCall; // ================================================================ const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { api_key: &'a str, base_url: &'a str, - http_client: Option, + http_client: T, } -impl<'a> ClientBuilder<'a> { +impl<'a, T> ClientBuilder<'a, T> +where + T: Default, +{ pub fn new(api_key: &'a str) -> Self { Self { api_key, base_url: DEEPSEEK_API_BASE_URL, - http_client: None, + http_client: Default::default(), } } +} +impl<'a, T> ClientBuilder<'a, T> { pub fn base_url(mut self, base_url: &'a str) -> Self { self.base_url = base_url; self } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); - self + pub fn with_client(self, http_client: U) -> ClientBuilder<'a, U> { + ClientBuilder { + api_key: self.api_key, + base_url: self.base_url, + http_client, + } } - pub fn build(self) -> Result { - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? 
- }; - - Ok(Client { + pub fn build(self) -> Client { + Client { base_url: self.base_url.to_string(), api_key: self.api_key.to_string(), - http_client, - }) + http_client: self.http_client, + } } } #[derive(Clone)] -pub struct Client { +pub struct Client { pub base_url: String, api_key: String, - http_client: HttpClient, + http_client: T, } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client +where + T: std::fmt::Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -93,7 +100,10 @@ impl std::fmt::Debug for Client { } } -impl Client { +impl Client +where + T: Default, +{ /// Create a new DeepSeek client builder. /// /// # Example @@ -104,7 +114,7 @@ impl Client { /// let deepseek = Client::builder("your-deepseek-api-key") /// .build() /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_> { + pub fn builder(api_key: &str) -> ClientBuilder<'_, T> { ClientBuilder::new(api_key) } @@ -113,23 +123,59 @@ impl Client { /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 
pub fn new(api_key: &str) -> Self { - Self::builder(api_key) - .build() - .expect("DeepSeek client should build") + Self::builder(api_key).build() } +} - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { +impl Client +where + T: HttpClientExt, +{ + fn req( + &self, + method: http_client::Method, + path: &str, + ) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.post(url).bearer_auth(&self.api_key) + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(http::Error::from)?; + + Ok(http_client::Request::builder() + .method(method) + .uri(url) + .header("Authorization", auth_header)) + } + + pub(crate) fn post(&self, path: &str) -> http_client::Result { + self.req(http_client::Method::POST, path) } - pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder { + pub(crate) fn get(&self, path: &str) -> http_client::Result { + self.req(http_client::Method::GET, path) + } + + async fn send( + &self, + req: http_client::Request, + ) -> http_client::Result>> + where + U: Into, + R: From + Send, + { + self.http_client.request(req).await + } +} + +impl Client { + fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.get(url).bearer_auth(&self.api_key) + + self.http_client.post(url).bearer_auth(&self.api_key) } } -impl ProviderClient for Client { +impl ProviderClient for Client { // If you prefer the environment variable approach: fn from_env() -> Self { let api_key = std::env::var("DEEPSEEK_API_KEY").expect("DEEPSEEK_API_KEY not set"); @@ -144,11 +190,11 @@ impl ProviderClient for Client { } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; /// Creates a DeepSeek completion model with the given `model_name`. 
- fn completion_model(&self, model_name: &str) -> CompletionModel { + fn completion_model(&self, model_name: &str) -> CompletionModel { CompletionModel { client: self.clone(), model: model_name.to_string(), @@ -156,19 +202,27 @@ impl CompletionClient for Client { } } -impl VerifyClient for Client { +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let response = self.get("/user/balance").send().await?; + let req = self + .get("/user/balance")? + .body(http_client::NoBody) + .map_err(http_client::Error::from)?; + + let response = self.send(req).await?; + match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication), reqwest::StatusCode::INTERNAL_SERVER_ERROR | reqwest::StatusCode::SERVICE_UNAVAILABLE => { - Err(VerifyError::ProviderError(response.text().await?)) + let text = http_client::text(response).await?; + Err(VerifyError::ProviderError(text)) } _ => { - response.error_for_status()?; + // TODO: `HttpClientExt` equivalent + //response.error_for_status()?; Ok(()) } } @@ -179,7 +233,7 @@ impl_conversion_traits!( AsEmbeddings, AsTranscription, AsImageGeneration, - AsAudioGeneration for Client + AsAudioGeneration for Client ); #[derive(Debug, Deserialize)] @@ -512,12 +566,12 @@ impl TryFrom for completion::CompletionResponse { + pub client: Client, pub model: String, } -impl CompletionModel { +impl CompletionModel { fn create_completion_request( &self, completion_request: CompletionRequest, @@ -573,7 +627,7 @@ impl CompletionModel { } } -impl completion::CompletionModel for CompletionModel { +impl completion::CompletionModel for CompletionModel { type Response = CompletionResponse; type StreamingResponse = StreamingCompletionResponse; @@ -591,13 +645,17 @@ impl completion::CompletionModel for CompletionModel { let response = self .client - .post("/chat/completions") + .reqwest_post("/chat/completions") 
.json(&request) .send() - .await?; + .await + .map_err(|e| http_client::Error::Instance(e.into()))?; if response.status().is_success() { - let t = response.text().await?; + let t: String = response + .text() + .await + .map_err(|e| http_client::Error::Instance(e.into()))?; tracing::debug!(target: "rig", "DeepSeek completion: {}", t); match serde_json::from_str::>(&t)? { @@ -605,7 +663,12 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } } else { - Err(CompletionError::ProviderError(response.text().await?)) + Err(CompletionError::ProviderError( + response + .text() + .await + .map_err(|e| http_client::Error::Instance(e.into()))?, + )) } } @@ -624,7 +687,7 @@ impl completion::CompletionModel for CompletionModel { json!({"stream": true, "stream_options": {"include_usage": true}}), ); - let builder = self.client.post("/chat/completions").json(&request); + let builder = self.client.reqwest_post("/chat/completions").json(&request); send_compatible_streaming_request(builder).await } } diff --git a/rig-core/src/providers/galadriel.rs b/rig-core/src/providers/galadriel.rs index ad6951930..472c8b46b 100644 --- a/rig-core/src/providers/galadriel.rs +++ b/rig-core/src/providers/galadriel.rs @@ -14,6 +14,7 @@ use super::openai; use crate::client::{ ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError, }; +use crate::http_client::{self, HttpClientExt}; use crate::json_utils::merge; use crate::message::MessageError; use crate::providers::openai::send_compatible_streaming_request; @@ -23,6 +24,7 @@ use crate::{ completion::{self, CompletionError, CompletionRequest}, impl_conversion_traits, json_utils, message, }; +use bytes::Bytes; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; @@ -31,23 +33,28 @@ use serde_json::{Value, json}; // ================================================================ const GALADRIEL_API_BASE_URL: &str = 
"https://api.galadriel.com/v1/verified"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { api_key: &'a str, fine_tune_api_key: Option<&'a str>, base_url: &'a str, - http_client: Option, + http_client: T, } -impl<'a> ClientBuilder<'a> { +impl<'a, T> ClientBuilder<'a, T> +where + T: Default, +{ pub fn new(api_key: &'a str) -> Self { Self { api_key, fine_tune_api_key: None, base_url: GALADRIEL_API_BASE_URL, - http_client: None, + http_client: Default::default(), } } +} +impl<'a, T> ClientBuilder<'a, T> { pub fn fine_tune_api_key(mut self, fine_tune_api_key: &'a str) -> Self { self.fine_tune_api_key = Some(fine_tune_api_key); self @@ -58,35 +65,36 @@ impl<'a> ClientBuilder<'a> { self } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); - self + pub fn with_client(self, http_client: U) -> ClientBuilder<'a, U> { + ClientBuilder { + api_key: self.api_key, + fine_tune_api_key: self.fine_tune_api_key, + base_url: self.base_url, + http_client, + } } - pub fn build(self) -> Result { - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? 
- }; - - Ok(Client { + pub fn build(self) -> Client { + Client { base_url: self.base_url.to_string(), api_key: self.api_key.to_string(), fine_tune_api_key: self.fine_tune_api_key.map(|x| x.to_string()), - http_client, - }) + http_client: self.http_client, + } } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, fine_tune_api_key: Option, - http_client: reqwest::Client, + http_client: T, } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client +where + T: std::fmt::Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -97,7 +105,10 @@ impl std::fmt::Debug for Client { } } -impl Client { +impl Client +where + T: Default, +{ /// Create a new Galadriel client builder. /// /// # Example @@ -108,7 +119,7 @@ impl Client { /// let galadriel = Client::builder("your-galadriel-api-key") /// .build() /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_> { + pub fn builder(api_key: &str) -> ClientBuilder<'_, T> { ClientBuilder::new(api_key) } @@ -117,24 +128,55 @@ impl Client { /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 
pub fn new(api_key: &str) -> Self { - Self::builder(api_key) - .build() - .expect("Galadriel client should build") + Self::builder(api_key).build() } +} - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { +impl Client +where + T: HttpClientExt, +{ + pub(crate) fn post(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let mut client = self.http_client.post(url).bearer_auth(&self.api_key); + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(http::Error::from)?; + + let mut req = http_client::Request::post(url).header("Authorization", auth_header); if let Some(fine_tune_key) = self.fine_tune_api_key.clone() { - client = client.header("Fine-Tune-Authorization", fine_tune_key); + req = req.header("Fine-Tune-Authorization", fine_tune_key); } - client + Ok(req) + } + + async fn send( + &self, + req: http_client::Request, + ) -> http_client::Result>> + where + U: Into, + R: From + Send, + { + self.http_client.request(req).await } } -impl ProviderClient for Client { +impl Client { + fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { + let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let mut req = self.http_client.post(url).bearer_auth(&self.api_key); + + if let Some(fine_tune_key) = self.fine_tune_api_key.clone() { + req = req.header("Fine-Tune-Authorization", fine_tune_key) + } + + req + } +} + +impl ProviderClient for Client { /// Create a new Galadriel client from the `GALADRIEL_API_KEY` environment variable, /// and optionally from the `GALADRIEL_FINE_TUNE_API_KEY` environment variable. /// Panics if the `GALADRIEL_API_KEY` environment variable is not set. 
@@ -145,7 +187,7 @@ impl ProviderClient for Client { if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() { builder = builder.fine_tune_api_key(fine_tune_api_key); } - builder.build().expect("Galadriel client should build") + builder.build() } fn from_val(input: crate::client::ProviderValue) -> Self { @@ -157,12 +199,12 @@ impl ProviderClient for Client { if let Some(fine_tune_key) = fine_tune_key.as_deref() { builder = builder.fine_tune_api_key(fine_tune_key); } - builder.build().expect("Galadriel client should build") + builder.build() } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; /// Create a completion model with the given name. /// @@ -175,12 +217,12 @@ impl CompletionClient for Client { /// /// let gpt4 = galadriel.completion_model(galadriel::GPT_4); /// ``` - fn completion_model(&self, model: &str) -> CompletionModel { + fn completion_model(&self, model: &str) -> CompletionModel { CompletionModel::new(self.clone(), model) } } -impl VerifyClient for Client { +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { // Could not find an API endpoint to verify the API key @@ -192,7 +234,7 @@ impl_conversion_traits!( AsEmbeddings, AsTranscription, AsImageGeneration, - AsAudioGeneration for Client + AsAudioGeneration for Client ); #[derive(Debug, Deserialize)] @@ -466,13 +508,23 @@ pub struct Function { } #[derive(Clone)] -pub struct CompletionModel { - client: Client, +pub struct CompletionModel { + client: Client, /// Name of the model (e.g.: gpt-3.5-turbo-1106) pub model: String, } -impl CompletionModel { +impl CompletionModel +where + T: HttpClientExt, +{ + pub fn new(client: Client, model: &str) -> Self { + Self { + client, + model: model.to_string(), + } + } + pub(crate) fn create_completion_request( &self, completion_request: CompletionRequest, @@ -528,16 +580,7 @@ 
impl CompletionModel { } } -impl CompletionModel { - pub fn new(client: Client, model: &str) -> Self { - Self { - client, - model: model.to_string(), - } - } -} - -impl completion::CompletionModel for CompletionModel { +impl completion::CompletionModel for CompletionModel { type Response = CompletionResponse; type StreamingResponse = openai::StreamingCompletionResponse; @@ -546,20 +589,23 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { - let request = self.create_completion_request(completion_request)?; + let body = self.create_completion_request(completion_request)?; + let body = serde_json::to_vec(&body)?; - let response = self + let req = self .client - .post("/chat/completions") - .json(&request) - .send() - .await?; + .post("/chat/completions")? + .body(body) + .map_err(http_client::Error::from)?; + + let response = self.client.send(req).await?; if response.status().is_success() { - let t = response.text().await?; - tracing::debug!(target: "rig", "Galadriel completion error: {}", t); + let text = http_client::text(response).await?; + + tracing::debug!(target: "rig", "Galadriel completion: {}", text); - match serde_json::from_str::>(&t)? { + match serde_json::from_str::>(&text)?
{ ApiResponse::Ok(response) => { tracing::info!(target: "rig", "Galadriel completion token usage: {:?}", @@ -570,7 +616,8 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } } else { - Err(CompletionError::ProviderError(response.text().await?)) + let text = http_client::text(response).await?; + Err(CompletionError::ProviderError(text)) } } @@ -586,7 +633,7 @@ impl completion::CompletionModel for CompletionModel { json!({"stream": true, "stream_options": {"include_usage": true}}), ); - let builder = self.client.post("/chat/completions").json(&request); + let builder = self.client.reqwest_post("/chat/completions").json(&request); send_compatible_streaming_request(builder).await } diff --git a/rig-core/src/providers/groq.rs b/rig-core/src/providers/groq.rs index 0d2e6eb22..d481f0874 100644 --- a/rig-core/src/providers/groq.rs +++ b/rig-core/src/providers/groq.rs @@ -12,10 +12,9 @@ use reqwest_eventsource::{Event, RequestBuilderExt}; use std::collections::HashMap; use super::openai::{CompletionResponse, StreamingToolCall, TranscriptionResponse, Usage}; -use crate::client::{ - ClientBuilderError, CompletionClient, TranscriptionClient, VerifyClient, VerifyError, -}; +use crate::client::{CompletionClient, TranscriptionClient, VerifyClient, VerifyError}; use crate::completion::GetTokenUsage; +use crate::http_client::{self, HttpClientExt}; use crate::json_utils::merge; use async_stream::stream; use futures::StreamExt; @@ -40,54 +39,59 @@ use serde_json::{Value, json}; // ================================================================ const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { api_key: &'a str, base_url: &'a str, - http_client: Option, + http_client: T, } -impl<'a> ClientBuilder<'a> { +impl<'a, T> ClientBuilder<'a, T> +where + T: Default, +{ pub fn new(api_key: &'a str) -> Self { Self { api_key, 
base_url: GROQ_API_BASE_URL, - http_client: None, + http_client: Default::default(), } } +} +impl<'a, T> ClientBuilder<'a, T> { pub fn base_url(mut self, base_url: &'a str) -> Self { self.base_url = base_url; self } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); - self + pub fn with_client(self, http_client: U) -> ClientBuilder<'a, U> { + ClientBuilder { + api_key: self.api_key, + base_url: self.base_url, + http_client, + } } - pub fn build(self) -> Result { - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? - }; - - Ok(Client { + pub fn build(self) -> Client { + Client { base_url: self.base_url.to_string(), api_key: self.api_key.to_string(), - http_client, - }) + http_client: self.http_client, + } } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, - http_client: reqwest::Client, + http_client: T, } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client +where + T: std::fmt::Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -97,7 +101,10 @@ impl std::fmt::Debug for Client { } } -impl Client { +impl Client +where + T: Default, +{ /// Create a new Groq client builder. /// /// # Example @@ -108,7 +115,7 @@ impl Client { /// let groq = Client::builder("your-groq-api-key") /// .build() /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_> { + pub fn builder(api_key: &str) -> ClientBuilder<'_, T> { ClientBuilder::new(api_key) } @@ -117,23 +124,48 @@ impl Client { /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 
pub fn new(api_key: &str) -> Self { - Self::builder(api_key) - .build() - .expect("Groq client should build") + Self::builder(api_key).build() } +} - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { +impl Client +where + T: HttpClientExt, +{ + fn req( + &self, + method: http_client::Method, + path: &str, + ) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.post(url).bearer_auth(&self.api_key) + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(http::Error::from)?; + + Ok(http_client::Builder::new() + .method(method) + .uri(url) + .header("Authorization", auth_header)) + } + + fn post(&self, path: &str) -> http_client::Result { + self.req(http_client::Method::POST, path) } - pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder { + fn get(&self, path: &str) -> http_client::Result { + self.req(http_client::Method::GET, path) + } +} + +impl Client { + fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.get(url).bearer_auth(&self.api_key) + + self.http_client.post(url).bearer_auth(&self.api_key) } } -impl ProviderClient for Client { +impl ProviderClient for Client { /// Create a new Groq client from the `GROQ_API_KEY` environment variable. /// Panics if the environment variable is not set. fn from_env() -> Self { @@ -149,8 +181,8 @@ impl ProviderClient for Client { } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; /// Create a completion model with the given name. 
/// @@ -163,13 +195,13 @@ impl CompletionClient for Client { /// /// let gpt4 = groq.completion_model(groq::GPT_4); /// ``` - fn completion_model(&self, model: &str) -> CompletionModel { + fn completion_model(&self, model: &str) -> CompletionModel { CompletionModel::new(self.clone(), model) } } -impl TranscriptionClient for Client { - type TranscriptionModel = TranscriptionModel; +impl TranscriptionClient for Client { + type TranscriptionModel = TranscriptionModel; /// Create a transcription model with the given name. /// @@ -182,25 +214,32 @@ impl TranscriptionClient for Client { /// /// let gpt4 = groq.transcription_model(groq::WHISPER_LARGE_V3); /// ``` - fn transcription_model(&self, model: &str) -> TranscriptionModel { + fn transcription_model(&self, model: &str) -> TranscriptionModel { TranscriptionModel::new(self.clone(), model) } } -impl VerifyClient for Client { +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let response = self.get("/models").send().await?; + let req = self + .get("/models")? 
+ .body(http_client::NoBody) + .map_err(http_client::Error::from)?; + + let response = HttpClientExt::request(&self.http_client, req).await?; + match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication), reqwest::StatusCode::INTERNAL_SERVER_ERROR | reqwest::StatusCode::SERVICE_UNAVAILABLE | reqwest::StatusCode::BAD_GATEWAY => { - Err(VerifyError::ProviderError(response.text().await?)) + let text = http_client::text(response).await?; + Err(VerifyError::ProviderError(text)) } _ => { - response.error_for_status()?; + //response.error_for_status()?; Ok(()) } } @@ -210,7 +249,7 @@ impl VerifyClient for Client { impl_conversion_traits!( AsEmbeddings, AsImageGeneration, - AsAudioGeneration for Client + AsAudioGeneration for Client ); #[derive(Debug, Deserialize)] @@ -355,14 +394,14 @@ pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192"; pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768"; #[derive(Clone, Debug)] -pub struct CompletionModel { - client: Client, +pub struct CompletionModel { + client: Client, /// Name of the model (e.g.: deepseek-r1-distill-llama-70b) pub model: String, } -impl CompletionModel { - pub fn new(client: Client, model: &str) -> Self { +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -427,7 +466,7 @@ impl CompletionModel { } } -impl completion::CompletionModel for CompletionModel { +impl completion::CompletionModel for CompletionModel { type Response = CompletionResponse; type StreamingResponse = StreamingCompletionResponse; @@ -440,13 +479,18 @@ impl completion::CompletionModel for CompletionModel { let response = self .client - .post("/chat/completions") + .reqwest_post("/chat/completions") .json(&request) .send() - .await?; + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; if response.status().is_success() { - match response.json::>().await? 
{ + match response + .json::>() + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))? + { ApiResponse::Ok(response) => { tracing::info!(target: "rig", "groq completion token usage: {:?}", @@ -457,7 +501,11 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } } else { - Err(CompletionError::ProviderError(response.text().await?)) + Err(CompletionError::ProviderError( + response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } } @@ -476,7 +524,7 @@ impl completion::CompletionModel for CompletionModel { json!({"stream": true, "stream_options": {"include_usage": true}}), ); - let builder = self.client.post("/chat/completions").json(&request); + let builder = self.client.reqwest_post("/chat/completions").json(&request); send_compatible_streaming_request(builder).await } @@ -490,21 +538,21 @@ pub const WHISPER_LARGE_V3_TURBO: &str = "whisper-large-v3-turbo"; pub const DISTIL_WHISPER_LARGE_V3: &str = "distil-whisper-large-v3-en"; #[derive(Clone)] -pub struct TranscriptionModel { - client: Client, +pub struct TranscriptionModel { + client: Client, /// Name of the model (e.g.: gpt-3.5-turbo-1106) pub model: String, } -impl TranscriptionModel { - pub fn new(client: Client, model: &str) -> Self { +impl TranscriptionModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), } } } -impl transcription::TranscriptionModel for TranscriptionModel { +impl transcription::TranscriptionModel for TranscriptionModel { type Response = TranscriptionResponse; #[cfg_attr(feature = "worker", worker::send)] @@ -544,23 +592,30 @@ impl transcription::TranscriptionModel for TranscriptionModel { let response = self .client - .post("audio/transcriptions") + .reqwest_post("audio/transcriptions") .multipart(body) .send() - .await?; + .await + .map_err(|e| 
TranscriptionError::HttpError(http_client::Error::Instance(e.into())))?; if response.status().is_success() { match response .json::>() - .await? - { + .await + .map_err(|e| { + TranscriptionError::HttpError(http_client::Error::Instance(e.into())) + })? { ApiResponse::Ok(response) => response.try_into(), ApiResponse::Err(api_error_response) => Err(TranscriptionError::ProviderError( api_error_response.message, )), } } else { - Err(TranscriptionError::ProviderError(response.text().await?)) + Err(TranscriptionError::ProviderError( + response.text().await.map_err(|e| { + TranscriptionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } } } diff --git a/rig-core/src/providers/hyperbolic.rs b/rig-core/src/providers/hyperbolic.rs index c047c7bbd..78f316bdb 100644 --- a/rig-core/src/providers/hyperbolic.rs +++ b/rig-core/src/providers/hyperbolic.rs @@ -13,6 +13,7 @@ use super::openai::{AssistantContent, send_compatible_streaming_request}; use crate::client::{ ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError, }; +use crate::http_client::{self, HttpClientExt}; use crate::json_utils::merge_inplace; use crate::message; use crate::streaming::StreamingCompletionResponse; @@ -33,54 +34,59 @@ use serde_json::{Value, json}; // ================================================================ const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { api_key: &'a str, base_url: &'a str, - http_client: Option, + http_client: T, } -impl<'a> ClientBuilder<'a> { +impl<'a, T> ClientBuilder<'a, T> +where + T: Default, +{ pub fn new(api_key: &'a str) -> Self { Self { api_key, base_url: HYPERBOLIC_API_BASE_URL, - http_client: None, + http_client: Default::default(), } } +} +impl<'a, T> ClientBuilder<'a, T> { pub fn base_url(mut self, base_url: &'a str) -> Self { self.base_url = base_url; self } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - 
self.http_client = Some(client); - self + pub fn with_client(self, http_client: U) -> ClientBuilder<'a, U> { + ClientBuilder { + api_key: self.api_key, + base_url: self.base_url, + http_client, + } } - pub fn build(self) -> Result { - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? - }; - - Ok(Client { + pub fn build(self) -> Client { + Client { base_url: self.base_url.to_string(), api_key: self.api_key.to_string(), - http_client, - }) + http_client: self.http_client, + } } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, - http_client: reqwest::Client, + http_client: T, } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client +where + T: std::fmt::Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -90,7 +96,10 @@ impl std::fmt::Debug for Client { } } -impl Client { +impl Client +where + T: Default, +{ /// Create a new Hyperbolic client builder. /// /// # Example @@ -101,7 +110,7 @@ impl Client { /// let hyperbolic = Client::builder("your-hyperbolic-api-key") /// .build() /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_> { + pub fn builder(api_key: &str) -> ClientBuilder<'_, T> { ClientBuilder::new(api_key) } @@ -110,23 +119,48 @@ impl Client { /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 
pub fn new(api_key: &str) -> Self { - Self::builder(api_key) - .build() - .expect("Hyperbolic client should build") + Self::builder(api_key).build() } +} - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { +impl Client +where + T: HttpClientExt, +{ + fn req( + &self, + method: http_client::Method, + path: &str, + ) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.post(url).bearer_auth(&self.api_key) + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(http::Error::from)?; + + Ok(http_client::Builder::new() + .method(method) + .uri(url) + .header("Authorization", auth_header)) + } + + fn post(&self, path: &str) -> http_client::Result { + self.req(http_client::Method::POST, path) } - pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder { + fn get(&self, path: &str) -> http_client::Result { + self.req(http_client::Method::GET, path) + } +} + +impl Client { + fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.get(url).bearer_auth(&self.api_key) + + self.http_client.post(url).bearer_auth(&self.api_key) } } -impl ProviderClient for Client { +impl ProviderClient for Client { /// Create a new Hyperbolic client from the `HYPERBOLIC_API_KEY` environment variable. /// Panics if the environment variable is not set. fn from_env() -> Self { @@ -142,8 +176,8 @@ impl ProviderClient for Client { } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; /// Create a completion model with the given name. 
/// @@ -156,23 +190,32 @@ impl CompletionClient for Client { /// /// let llama_3_1_8b = hyperbolic.completion_model(hyperbolic::LLAMA_3_1_8B); /// ``` - fn completion_model(&self, model: &str) -> CompletionModel { + fn completion_model(&self, model: &str) -> CompletionModel { CompletionModel::new(self.clone(), model) } } -impl VerifyClient for Client { +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let response = self.get("/users/me").send().await?; + let req = self + .get("/users/me")? + .body(http_client::NoBody) + .map_err(http_client::Error::from)?; + + let response = HttpClientExt::request(&self.http_client, req).await?; + match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication), - reqwest::StatusCode::INTERNAL_SERVER_ERROR => { - Err(VerifyError::ProviderError(response.text().await?)) + reqwest::StatusCode::INTERNAL_SERVER_ERROR + | reqwest::StatusCode::SERVICE_UNAVAILABLE + | reqwest::StatusCode::BAD_GATEWAY => { + let text = http_client::text(response).await?; + Err(VerifyError::ProviderError(text)) } _ => { - response.error_for_status()?; + //response.error_for_status()?; Ok(()) } } @@ -181,7 +224,7 @@ impl VerifyClient for Client { impl_conversion_traits!( AsEmbeddings, - AsTranscription for Client + AsTranscription for Client ); #[derive(Debug, Deserialize)] @@ -341,13 +384,20 @@ pub struct Choice { } #[derive(Clone)] -pub struct CompletionModel { - client: Client, +pub struct CompletionModel { + client: Client, /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1) pub model: String, } -impl CompletionModel { +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { + Self { + client, + model: model.to_string(), + } + } + pub(crate) fn create_completion_request( &self, completion_request: CompletionRequest, ) -> Result { @@ -391,16 +441,7 @@ impl CompletionModel { } } -impl CompletionModel
{ - pub fn new(client: Client, model: &str) -> Self { - Self { - client, - model: model.to_string(), - } - } -} - -impl completion::CompletionModel for CompletionModel { +impl completion::CompletionModel for CompletionModel { type Response = CompletionResponse; type StreamingResponse = openai::StreamingCompletionResponse; @@ -413,13 +454,18 @@ impl completion::CompletionModel for CompletionModel { let response = self .client - .post("/v1/chat/completions") + .reqwest_post("/v1/chat/completions") .json(&request) .send() - .await?; + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; if response.status().is_success() { - match response.json::>().await? { + match response + .json::>() + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))? + { ApiResponse::Ok(response) => { tracing::info!(target: "rig", "Hyperbolic completion token usage: {:?}", @@ -431,7 +477,11 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } } else { - Err(CompletionError::ProviderError(response.text().await?)) + Err(CompletionError::ProviderError( + response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } } @@ -447,7 +497,10 @@ impl completion::CompletionModel for CompletionModel { json!({"stream": true, "stream_options": {"include_usage": true}}), ); - let builder = self.client.post("/v1/chat/completions").json(&request); + let builder = self + .client + .reqwest_post("/v1/chat/completions") + .json(&request); send_compatible_streaming_request(builder).await } @@ -465,9 +518,9 @@ pub use image_generation::*; mod image_generation { use super::{ApiResponse, Client}; use crate::client::ImageGenerationClient; - use crate::image_generation; use crate::image_generation::{ImageGenerationError, ImageGenerationRequest}; use crate::json_utils::merge_inplace; + use crate::{http_client, 
image_generation}; use base64::Engine; use base64::prelude::BASE64_STANDARD; use serde::Deserialize; @@ -482,13 +535,13 @@ mod image_generation { pub const SD1_5_CONTROLNET: &str = "SD1.5-ControlNet"; #[derive(Clone)] - pub struct ImageGenerationModel { - client: Client, + pub struct ImageGenerationModel { + client: Client, pub model: String, } - impl ImageGenerationModel { - pub(crate) fn new(client: Client, model: &str) -> ImageGenerationModel { + impl ImageGenerationModel { + pub(crate) fn new(client: Client, model: &str) -> ImageGenerationModel { Self { client, model: model.to_string(), @@ -523,7 +576,7 @@ mod image_generation { } } - impl image_generation::ImageGenerationModel for ImageGenerationModel { + impl image_generation::ImageGenerationModel for ImageGenerationModel { type Response = ImageGenerationResponse; #[cfg_attr(feature = "worker", worker::send)] @@ -545,31 +598,38 @@ mod image_generation { let response = self .client - .post("/v1/image/generation") + .reqwest_post("/v1/image/generation") .json(&request) .send() - .await?; + .await + .map_err(|e| { + ImageGenerationError::HttpError(http_client::Error::Instance(e.into())) + })?; if !response.status().is_success() { return Err(ImageGenerationError::ProviderError(format!( "{}: {}", response.status().as_str(), - response.text().await? + response.text().await.map_err(|e| { + ImageGenerationError::HttpError(http_client::Error::Instance(e.into())) + })? ))); } match response .json::>() - .await? - { + .await + .map_err(|e| { + ImageGenerationError::HttpError(http_client::Error::Instance(e.into())) + })? { ApiResponse::Ok(response) => response.try_into(), ApiResponse::Err(err) => Err(ImageGenerationError::ResponseError(err.message)), } } } - impl ImageGenerationClient for Client { - type ImageGenerationModel = ImageGenerationModel; + impl ImageGenerationClient for Client { + type ImageGenerationModel = ImageGenerationModel; /// Create an image generation model with the given name. 
/// @@ -582,7 +642,7 @@ mod image_generation { /// /// let llama_3_1_8b = hyperbolic.image_generation_model(hyperbolic::SSD); /// ``` - fn image_generation_model(&self, model: &str) -> ImageGenerationModel { + fn image_generation_model(&self, model: &str) -> ImageGenerationModel { ImageGenerationModel::new(self.clone(), model) } } @@ -607,13 +667,13 @@ mod audio_generation { use serde_json::json; #[derive(Clone)] - pub struct AudioGenerationModel { - client: Client, + pub struct AudioGenerationModel { + client: Client, pub language: String, } - impl AudioGenerationModel { - pub(crate) fn new(client: Client, language: &str) -> AudioGenerationModel { + impl AudioGenerationModel { + pub(crate) fn new(client: Client, language: &str) -> AudioGenerationModel { Self { client, language: language.to_string(), @@ -643,7 +703,7 @@ mod audio_generation { } } - impl audio_generation::AudioGenerationModel for AudioGenerationModel { + impl audio_generation::AudioGenerationModel for AudioGenerationModel { type Response = AudioGenerationResponse; #[cfg_attr(feature = "worker", worker::send)] @@ -661,29 +721,36 @@ mod audio_generation { let response = self .client - .post("/v1/audio/generation") + .reqwest_post("/v1/audio/generation") .json(&request) .send() - .await?; + .await + .map_err(|e| { + AudioGenerationError::HttpError(http_client::Error::Instance(e.into())) + })?; if !response.status().is_success() { return Err(AudioGenerationError::ProviderError(format!( "{}: {}", response.status(), - response.text().await? + response.text().await.map_err(|e| { + AudioGenerationError::HttpError(http_client::Error::Instance(e.into())) + })? ))); } match serde_json::from_str::>( - &response.text().await?, + &response.text().await.map_err(|e| { + AudioGenerationError::HttpError(http_client::Error::Instance(e.into())) + })?, )? 
{ ApiResponse::Ok(response) => response.try_into(), ApiResponse::Err(err) => Err(AudioGenerationError::ProviderError(err.message)), } } } - impl AudioGenerationClient for Client { - type AudioGenerationModel = AudioGenerationModel; + impl AudioGenerationClient for Client { + type AudioGenerationModel = AudioGenerationModel; /// Create a completion model with the given name. /// diff --git a/rig-core/src/providers/mira.rs b/rig-core/src/providers/mira.rs index f3849fba8..ed75566dd 100644 --- a/rig-core/src/providers/mira.rs +++ b/rig-core/src/providers/mira.rs @@ -7,9 +7,8 @@ //! let client = mira::Client::new("YOUR_API_KEY"); //! //! ``` -use crate::client::{ - ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError, -}; +use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError}; +use crate::http_client::{self, HttpClientExt}; use crate::json_utils::merge; use crate::message::{Document, DocumentSourceKind}; use crate::providers::openai; @@ -35,7 +34,7 @@ pub enum MiraError { #[error("API error: {0}")] ApiError(u16), #[error("Request error: {0}")] - RequestError(#[from] reqwest::Error), + RequestError(#[from] http_client::Error), #[error("UTF-8 error: {0}")] Utf8Error(#[from] FromUtf8Error), #[error("JSON error: {0}")] @@ -111,32 +110,40 @@ struct ModelInfo { id: String, } -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { api_key: &'a str, base_url: &'a str, - http_client: Option, + http_client: T, } -impl<'a> ClientBuilder<'a> { +impl<'a, T> ClientBuilder<'a, T> +where + T: Default, +{ pub fn new(api_key: &'a str) -> Self { Self { api_key, base_url: MIRA_API_BASE_URL, - http_client: None, + http_client: Default::default(), } } +} +impl<'a, T> ClientBuilder<'a, T> { pub fn base_url(mut self, base_url: &'a str) -> Self { self.base_url = base_url; self } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); - self + pub fn with_client(self, http_client: U) 
-> ClientBuilder<'a, U> { + ClientBuilder { + api_key: self.api_key, + base_url: self.base_url, + http_client, + } } - pub fn build(self) -> Result { + pub fn build(self) -> Client { let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); headers.insert( @@ -147,31 +154,29 @@ impl<'a> ClientBuilder<'a> { reqwest::header::USER_AGENT, HeaderValue::from_static("rig-client/1.0"), ); - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? - }; - Ok(Client { + Client { base_url: self.base_url.to_string(), - http_client, + http_client: self.http_client, api_key: self.api_key.to_string(), headers, - }) + } } } #[derive(Clone)] /// Client for interacting with the Mira API -pub struct Client { +pub struct Client { base_url: String, - http_client: reqwest::Client, + http_client: T, api_key: String, headers: HeaderMap, } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client +where + T: std::fmt::Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -182,7 +187,10 @@ impl std::fmt::Debug for Client { } } -impl Client { +impl Client +where + T: Default, +{ /// Create a new Mira client builder. /// /// # Example @@ -193,7 +201,7 @@ impl Client { /// let mira = Client::builder("your-mira-api-key") /// .build() /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_> { + pub fn builder(api_key: &str) -> ClientBuilder<'_, T> { ClientBuilder::new(api_key) } @@ -202,25 +210,33 @@ impl Client { /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 
pub fn new(api_key: &str) -> Self { - Self::builder(api_key) - .build() - .expect("Mira client should build") + Self::builder(api_key).build() } +} +impl Client +where + T: HttpClientExt, +{ /// List available models pub async fn list_models(&self) -> Result, MiraError> { - let response = self.get("/v1/models").send().await?; + let req = self.get("/v1/models").and_then(|req| { + req.body(http_client::NoBody) + .map_err(|e| http_client::Error::Protocol(e.into())) + })?; + + let response = self.http_client.request(req).await?; let status = response.status(); if !status.is_success() { // Log the error text but don't store it in an unused variable - let _error_text = response.text().await.unwrap_or_default(); - tracing::error!("Error response: {}", _error_text); + let error_text = http_client::text(response).await.unwrap_or_default(); + tracing::error!("Error response: {}", error_text); return Err(MiraError::ApiError(status.as_u16())); } - let response_text = response.text().await?; + let response_text = http_client::text(response).await?; let models: ModelsResponse = serde_json::from_str(&response_text).map_err(|e| { tracing::error!("Failed to parse response: {}", e); @@ -230,24 +246,49 @@ impl Client { Ok(models.data.into_iter().map(|model| model.id).collect()) } - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { + fn req( + &self, + method: http_client::Method, + path: &str, + ) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client - .post(url) - .bearer_auth(&self.api_key) - .headers(self.headers.clone()) + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(http::Error::from)?; + + let mut req = http_client::Builder::new() + .method(method) + .uri(url) + .header("Authorization", auth_header); + + if let Some(hs) = req.headers_mut() { + hs.extend(self.headers.clone()); + } + + Ok(req) + } + + pub(crate) fn post(&self, path: &str) -> 
http_client::Result { + self.req(http_client::Method::POST, path) } - pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder { + pub(crate) fn get(&self, path: &str) -> http_client::Result { + self.req(http_client::Method::GET, path) + } +} + +impl Client { + fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + self.http_client - .get(url) .post(url) .bearer_auth(&self.api_key) .headers(self.headers.clone()) } } -impl ProviderClient for Client { +impl ProviderClient for Client { /// Create a new Mira client from the `MIRA_API_KEY` environment variable. /// Panics if the environment variable is not set. fn from_env() -> Self { @@ -263,26 +304,35 @@ impl ProviderClient for Client { } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; /// Create a completion model with the given name. - fn completion_model(&self, model: &str) -> CompletionModel { + fn completion_model(&self, model: &str) -> CompletionModel { CompletionModel::new(self.to_owned(), model) } } -impl VerifyClient for Client { +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let response = self.get("/user-credits").send().await?; + let req = self + .get("/user-credits")? 
+ .body(http_client::NoBody) + .map_err(http_client::Error::from)?; + + let response = HttpClientExt::request(&self.http_client, req).await?; + match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication), - reqwest::StatusCode::INTERNAL_SERVER_ERROR => { - Err(VerifyError::ProviderError(response.text().await?)) + reqwest::StatusCode::INTERNAL_SERVER_ERROR + | reqwest::StatusCode::SERVICE_UNAVAILABLE + | reqwest::StatusCode::BAD_GATEWAY => { + let text = http_client::text(response).await?; + Err(VerifyError::ProviderError(text)) } _ => { - response.error_for_status()?; + //response.error_for_status()?; Ok(()) } } @@ -293,18 +343,18 @@ impl_conversion_traits!( AsEmbeddings, AsTranscription, AsImageGeneration, - AsAudioGeneration for Client + AsAudioGeneration for Client ); #[derive(Clone)] -pub struct CompletionModel { - client: Client, +pub struct CompletionModel { + client: Client, /// Name of the model pub model: String, } -impl CompletionModel { - pub fn new(client: Client, model: &str) -> Self { +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -392,7 +442,7 @@ impl CompletionModel { } } -impl completion::CompletionModel for CompletionModel { +impl completion::CompletionModel for CompletionModel { type Response = CompletionResponse; type StreamingResponse = openai::StreamingCompletionResponse; @@ -412,7 +462,7 @@ impl completion::CompletionModel for CompletionModel { let response = self .client - .post("/v1/chat/completions") + .reqwest_post("/v1/chat/completions") .json(&mira_request) .send() .await @@ -443,7 +493,10 @@ impl completion::CompletionModel for CompletionModel { request = merge(request, json!({"stream": true})); - let builder = self.client.post("/v1/chat/completions").json(&request); + let builder = self + .client + .reqwest_post("/v1/chat/completions") + .json(&request); 
send_compatible_streaming_request(builder).await } diff --git a/rig-core/src/providers/openrouter/client.rs b/rig-core/src/providers/openrouter/client.rs index 8c737c6b2..797de2ae5 100644 --- a/rig-core/src/providers/openrouter/client.rs +++ b/rig-core/src/providers/openrouter/client.rs @@ -1,8 +1,10 @@ use crate::{ client::{ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError}, + http_client::{self, HttpClientExt}, impl_conversion_traits, }; use serde::{Deserialize, Serialize}; +use std::fmt::Debug; use super::completion::CompletionModel; @@ -11,54 +13,59 @@ use super::completion::CompletionModel; // ================================================================ const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { api_key: &'a str, base_url: &'a str, - http_client: Option, + http_client: T, } -impl<'a> ClientBuilder<'a> { +impl<'a, T> ClientBuilder<'a, T> +where + T: Default, +{ pub fn new(api_key: &'a str) -> Self { Self { api_key, base_url: OPENROUTER_API_BASE_URL, - http_client: None, + http_client: Default::default(), } } +} - pub fn base_url(mut self, base_url: &'a str) -> Self { - self.base_url = base_url; - self +impl<'a, T> ClientBuilder<'a, T> { + pub fn with_client(self, http_client: U) -> ClientBuilder<'a, U> { + ClientBuilder { + api_key: self.api_key, + base_url: self.base_url, + http_client, + } } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); + pub fn base_url(mut self, base_url: &'a str) -> Self { + self.base_url = base_url; self } - pub fn build(self) -> Result { - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? 
- }; - - Ok(Client { + pub fn build(self) -> Client { + Client { base_url: self.base_url.to_string(), api_key: self.api_key.to_string(), - http_client, - }) + http_client: self.http_client, + } } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, - http_client: reqwest::Client, + http_client: T, } -impl std::fmt::Debug for Client { +impl Debug for Client +where + T: Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -68,7 +75,22 @@ impl std::fmt::Debug for Client { } } -impl Client { +impl Client { + pub(crate) fn reqwest_client(&self) -> &reqwest::Client { + &self.http_client + } + + pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { + let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + + self.http_client.post(url).bearer_auth(&self.api_key) + } +} + +impl Client +where + T: Default, +{ /// Create a new OpenRouter client builder. /// /// # Example @@ -79,7 +101,7 @@ impl Client { /// let openrouter = Client::builder("your-openrouter-api-key") /// .build() /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_> { + pub fn builder(api_key: &str) -> ClientBuilder<'_, T> { ClientBuilder::new(api_key) } @@ -88,23 +110,31 @@ impl Client { /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 
pub fn new(api_key: &str) -> Self { - Self::builder(api_key) - .build() - .expect("OpenRouter client should build") + Self::builder(api_key).build() } +} - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { +impl Client { + pub(crate) fn post(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.post(url).bearer_auth(&self.api_key) + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(|e| http_client::Error::Protocol(e.into()))?; + + Ok(http_client::Request::post(url).header("Authorization", auth_header)) } - pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder { + pub(crate) fn get(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.get(url).bearer_auth(&self.api_key) + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(|e| http_client::Error::Protocol(e.into()))?; + + Ok(http_client::Request::get(url).header("Authorization", auth_header)) } } -impl ProviderClient for Client { +impl ProviderClient for Client { /// Create a new openrouter client from the `OPENROUTER_API_KEY` environment variable. /// Panics if the environment variable is not set. fn from_env() -> Self { @@ -120,8 +150,8 @@ impl ProviderClient for Client { } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; /// Create a completion model with the given name. 
/// @@ -134,23 +164,30 @@ impl CompletionClient for Client { /// /// let llama_3_1_8b = openrouter.completion_model(openrouter::LLAMA_3_1_8B); /// ``` - fn completion_model(&self, model: &str) -> CompletionModel { + fn completion_model(&self, model: &str) -> CompletionModel { CompletionModel::new(self.clone(), model) } } -impl VerifyClient for Client { +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let response = self.get("/key").send().await?; + let req = self + .get("/key")? + .body(http_client::NoBody) + .map_err(|e| VerifyError::HttpError(e.into()))?; + + let response = HttpClientExt::request(&self.http_client, req).await?; + match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication), reqwest::StatusCode::INTERNAL_SERVER_ERROR => { - Err(VerifyError::ProviderError(response.text().await?)) + let text = http_client::text(response).await?; + Err(VerifyError::ProviderError(text)) } _ => { - response.error_for_status()?; + //response.error_for_status()?; Ok(()) } } @@ -161,7 +198,7 @@ impl_conversion_traits!( AsEmbeddings, AsTranscription, AsImageGeneration, - AsAudioGeneration for Client + AsAudioGeneration for Client ); #[derive(Debug, Deserialize)] diff --git a/rig-core/src/providers/openrouter/completion.rs b/rig-core/src/providers/openrouter/completion.rs index 83cd6a2ed..666771926 100644 --- a/rig-core/src/providers/openrouter/completion.rs +++ b/rig-core/src/providers/openrouter/completion.rs @@ -5,7 +5,7 @@ use super::client::{ApiErrorResponse, ApiResponse, Client, Usage}; use crate::{ OneOrMany, completion::{self, CompletionError, CompletionRequest}, - json_utils, + http_client, json_utils, providers::openai::Message, }; use serde_json::{Value, json}; @@ -122,14 +122,14 @@ pub struct Choice { } #[derive(Clone)] -pub struct CompletionModel { - pub(crate) client: Client, +pub struct CompletionModel 
{ + pub(crate) client: Client, /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1) pub model: String, } -impl CompletionModel { - pub fn new(client: Client, model: &str) -> Self { +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -182,7 +182,7 @@ impl CompletionModel { } } -impl completion::CompletionModel for CompletionModel { +impl completion::CompletionModel for CompletionModel { type Response = CompletionResponse; type StreamingResponse = FinalCompletionResponse; @@ -195,13 +195,19 @@ impl completion::CompletionModel for CompletionModel { let response = self .client + .reqwest_client() .post("/chat/completions") .json(&request) .send() - .await?; + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; if response.status().is_success() { - match response.json::>().await? { + match response + .json::>() + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))? 
+ { ApiResponse::Ok(response) => { tracing::info!(target: "rig", "OpenRouter completion token usage: {:?}", @@ -214,7 +220,11 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } } else { - Err(CompletionError::ProviderError(response.text().await?)) + Err(CompletionError::ProviderError( + response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } } diff --git a/rig-core/src/providers/openrouter/streaming.rs b/rig-core/src/providers/openrouter/streaming.rs index 2cb66b3fc..09e5524a0 100644 --- a/rig-core/src/providers/openrouter/streaming.rs +++ b/rig-core/src/providers/openrouter/streaming.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use crate::{ completion::GetTokenUsage, - json_utils, + http_client, json_utils, message::{ToolCall, ToolFunction}, streaming::{self}, }; @@ -112,7 +112,7 @@ pub struct FinalCompletionResponse { pub usage: ResponseUsage, } -impl super::CompletionModel { +impl super::CompletionModel { pub(crate) async fn stream( &self, completion_request: CompletionRequest, @@ -122,7 +122,7 @@ impl super::CompletionModel { let request = json_utils::merge(request, json!({"stream": true})); - let builder = self.client.post("/chat/completions").json(&request); + let builder = self.client.reqwest_post("/chat/completions").json(&request); send_streaming_request(builder).await } @@ -131,13 +131,19 @@ impl super::CompletionModel { pub async fn send_streaming_request( request_builder: RequestBuilder, ) -> Result, CompletionError> { - let response = request_builder.send().await?; + let response = request_builder + .send() + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; if !response.status().is_success() { return Err(CompletionError::ProviderError(format!( "{}: {}", response.status(), - response.text().await? 
+ response + .text() + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))? ))); } @@ -152,7 +158,7 @@ pub async fn send_streaming_request( let chunk = match chunk_result { Ok(c) => c, Err(e) => { - yield Err(CompletionError::from(e)); + yield Err(CompletionError::from(http_client::Error::Instance(e.into()))); break; } }; diff --git a/rig-core/src/providers/together/client.rs b/rig-core/src/providers/together/client.rs index 4d36bbd66..647359d5a 100644 --- a/rig-core/src/providers/together/client.rs +++ b/rig-core/src/providers/together/client.rs @@ -1,8 +1,9 @@ use super::{M2_BERT_80M_8K_RETRIEVAL, completion::CompletionModel, embedding::EmbeddingModel}; -use crate::client::{ - ClientBuilderError, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError, - impl_conversion_traits, +use crate::{ + client::{EmbeddingsClient, ProviderClient, VerifyClient, VerifyError, impl_conversion_traits}, + http_client::{self, HttpClientExt}, }; +use bytes::Bytes; use rig::client::CompletionClient; // ================================================================ @@ -10,61 +11,66 @@ use rig::client::CompletionClient; // ================================================================ const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { api_key: &'a str, base_url: &'a str, - http_client: Option, + http_client: T, } -impl<'a> ClientBuilder<'a> { +impl<'a, T> ClientBuilder<'a, T> +where + T: Default, +{ pub fn new(api_key: &'a str) -> Self { Self { api_key, base_url: TOGETHER_AI_BASE_URL, - http_client: None, + http_client: Default::default(), } } +} +impl<'a, T> ClientBuilder<'a, T> { pub fn base_url(mut self, base_url: &'a str) -> Self { self.base_url = base_url; self } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); - self + pub fn with_client(self, http_client: U) -> ClientBuilder<'a, U> { + ClientBuilder { 
+ api_key: self.api_key, + base_url: self.base_url, + http_client, + } } - pub fn build(self) -> Result { + pub fn build(self) -> Client { let mut default_headers = reqwest::header::HeaderMap::new(); default_headers.insert( reqwest::header::CONTENT_TYPE, "application/json".parse().unwrap(), ); - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? - }; - - Ok(Client { + Client { base_url: self.base_url.to_string(), api_key: self.api_key.to_string(), default_headers, - http_client, - }) + http_client: self.http_client, + } } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, default_headers: reqwest::header::HeaderMap, api_key: String, - http_client: reqwest::Client, + http_client: T, } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client +where + T: std::fmt::Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -75,7 +81,10 @@ impl std::fmt::Debug for Client { } } -impl Client { +impl Client +where + T: Default, +{ /// Create a new Together AI client builder. /// /// # Example @@ -86,7 +95,7 @@ impl Client { /// let together_ai = Client::builder("your-together-ai-api-key") /// .build() /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_> { + pub fn builder(api_key: &str) -> ClientBuilder<'_, T> { ClientBuilder::new(api_key) } @@ -95,33 +104,77 @@ impl Client { /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 
pub fn new(api_key: &str) -> Self { - Self::builder(api_key) - .build() - .expect("Together AI client should build") + Self::builder(api_key).build() } +} - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { +impl Client +where + T: HttpClientExt, +{ + pub(crate) fn post(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); tracing::debug!("POST {}", url); - self.http_client - .post(url) - .bearer_auth(&self.api_key) - .headers(self.default_headers.clone()) + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(|e| http_client::Error::Protocol(e.into()))?; + let mut req = http_client::Request::post(url).header("Authorization", auth_header); + + if let Some(hs) = req.headers_mut() { + hs.extend(self.default_headers.clone()); + } + + Ok(req) } - pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder { + pub(crate) fn get(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); tracing::debug!("GET {}", url); + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(|e| http_client::Error::Protocol(e.into()))?; + + let mut req = http_client::Request::get(url).header("Authorization", auth_header); + + if let Some(hs) = req.headers_mut() { + hs.extend(self.default_headers.clone()); + } + + Ok(req) + } + + pub(crate) async fn send( + &self, + req: http_client::Request, + ) -> http_client::Result>> + where + U: Into, + R: From + Send, + { + self.http_client.request(req).await + } +} + +impl Client { + pub(crate) fn reqwest_client(&self) -> &reqwest::Client { + &self.http_client + } + + pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { + let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + + tracing::debug!("POST {}", url); + self.http_client - .get(url) + .post(url) .bearer_auth(&self.api_key) 
.headers(self.default_headers.clone()) } } -impl ProviderClient for Client { +impl ProviderClient for Client { /// Create a new Together AI client from the `TOGETHER_API_KEY` environment variable. /// Panics if the environment variable is not set. fn from_env() -> Self { @@ -137,17 +190,17 @@ impl ProviderClient for Client { } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; /// Create a completion model with the given name. - fn completion_model(&self, model: &str) -> CompletionModel { + fn completion_model(&self, model: &str) -> CompletionModel { CompletionModel::new(self.clone(), model) } } -impl EmbeddingsClient for Client { - type EmbeddingModel = EmbeddingModel; +impl EmbeddingsClient for Client { + type EmbeddingModel = EmbeddingModel; /// Create an embedding model with the given name. /// Note: default embedding dimension of 0 will be used if model is not known. @@ -162,7 +215,7 @@ impl EmbeddingsClient for Client { /// /// let embedding_model = together_ai.embedding_model(together_ai::embedding::EMBEDDING_V1); /// ``` - fn embedding_model(&self, model: &str) -> EmbeddingModel { + fn embedding_model(&self, model: &str) -> EmbeddingModel { let ndims = match model { M2_BERT_80M_8K_RETRIEVAL => 8192, _ => 0, @@ -182,30 +235,41 @@ impl EmbeddingsClient for Client { /// /// let embedding_model = together_ai.embedding_model_with_ndims("model-unknown-to-rig", 1024); /// ``` - fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel { + fn embedding_model_with_ndims( + &self, + model: &str, + ndims: usize, + ) -> EmbeddingModel { EmbeddingModel::new(self.clone(), model, ndims) } } -impl VerifyClient for Client { +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let response = self.get("/models").send().await?; + let req = self + .get("/models")? 
+ .body(http_client::NoBody) + .map_err(|e| VerifyError::HttpError(e.into()))?; + + let response = HttpClientExt::request(&self.http_client, req).await?; + match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication), reqwest::StatusCode::INTERNAL_SERVER_ERROR | reqwest::StatusCode::GATEWAY_TIMEOUT => { - Err(VerifyError::ProviderError(response.text().await?)) + let text = http_client::text(response).await?; + Err(VerifyError::ProviderError(text)) } _ => { - response.error_for_status()?; + //response.error_for_status()?; Ok(()) } } } } -impl_conversion_traits!(AsTranscription, AsImageGeneration, AsAudioGeneration for Client); +impl_conversion_traits!(AsTranscription, AsImageGeneration, AsAudioGeneration for Client); pub mod together_ai_api_types { use serde::Deserialize; diff --git a/rig-core/src/providers/together/completion.rs b/rig-core/src/providers/together/completion.rs index fd0e0f54a..a07445d47 100644 --- a/rig-core/src/providers/together/completion.rs +++ b/rig-core/src/providers/together/completion.rs @@ -5,7 +5,7 @@ use crate::{ completion::{self, CompletionError}, - json_utils, + http_client, json_utils, providers::openai, }; @@ -128,13 +128,13 @@ pub const WIZARDLM_13B_V1_2: &str = "WizardLM/WizardLM-13B-V1.2"; // ================================================================= #[derive(Clone)] -pub struct CompletionModel { - pub(crate) client: Client, +pub struct CompletionModel { + pub(crate) client: Client, pub model: String, } -impl CompletionModel { - pub fn new(client: Client, model: &str) -> Self { +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -188,7 +188,7 @@ impl CompletionModel { } } -impl completion::CompletionModel for CompletionModel { +impl completion::CompletionModel for CompletionModel { type Response = openai::CompletionResponse; type StreamingResponse = 
openai::StreamingCompletionResponse; @@ -201,16 +201,21 @@ impl completion::CompletionModel for CompletionModel { let response = self .client - .post("/v1/chat/completions") + .reqwest_post("/v1/chat/completions") .json(&request) .send() - .await?; + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; if response.status().is_success() { - let t = response.text().await?; - tracing::debug!(target: "rig", "Together completion error: {}", t); + let text = response + .text() + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; - match serde_json::from_str::>(&t)? { + tracing::debug!(target: "rig", "Together completion error: {}", text); + + match serde_json::from_str::>(&text)? { ApiResponse::Ok(response) => { tracing::info!(target: "rig", "Together completion token usage: {:?}", @@ -221,7 +226,11 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Error(err) => Err(CompletionError::ProviderError(err.error)), } } else { - Err(CompletionError::ProviderError(response.text().await?)) + Err(CompletionError::ProviderError( + response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } } diff --git a/rig-core/src/providers/together/embedding.rs b/rig-core/src/providers/together/embedding.rs index 1a79e4384..9efb8a4f9 100644 --- a/rig-core/src/providers/together/embedding.rs +++ b/rig-core/src/providers/together/embedding.rs @@ -6,7 +6,10 @@ use serde::Deserialize; use serde_json::json; -use crate::embeddings::{self, EmbeddingError}; +use crate::{ + embeddings::{self, EmbeddingError}, + http_client::{self, HttpClientExt}, +}; use super::{ Client, @@ -63,13 +66,16 @@ pub struct Usage { } #[derive(Clone)] -pub struct EmbeddingModel { - client: Client, +pub struct EmbeddingModel { + client: Client, pub model: String, ndims: usize, } -impl embeddings::EmbeddingModel for EmbeddingModel { +impl embeddings::EmbeddingModel for 
EmbeddingModel +where + T: HttpClientExt + Default + Clone + Send + 'static, +{ const MAX_DOCUMENTS: usize = 1024; // This might need to be adjusted based on Together AI's actual limit fn ndims(&self) -> usize { @@ -83,18 +89,24 @@ impl embeddings::EmbeddingModel for EmbeddingModel { ) -> Result, EmbeddingError> { let documents = documents.into_iter().collect::>(); - let response = self + let body = serde_json::to_vec(&json!({ + "model": self.model, + "input": documents, + }))?; + + let req = self .client - .post("/v1/embeddings") - .json(&json!({ - "model": self.model, - "input": documents, - })) - .send() - .await?; + .post("/v1/embeddings")? + .body(body) + .map_err(|e| EmbeddingError::HttpError(e.into()))?; + + let response = self.client.send(req).await?; if response.status().is_success() { - match response.json::>().await? { + let body: Vec = response.into_body().await?; + let body: ApiResponse = serde_json::from_slice(&body)?; + + match body { ApiResponse::Ok(response) => { if response.data.len() != documents.len() { return Err(EmbeddingError::ResponseError( @@ -115,13 +127,17 @@ impl embeddings::EmbeddingModel for EmbeddingModel { ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())), } } else { - Err(EmbeddingError::ProviderError(response.text().await?)) + let text = http_client::text(response).await?; + Err(EmbeddingError::ProviderError(text)) } } } -impl EmbeddingModel { - pub fn new(client: Client, model: &str, ndims: usize) -> Self { +impl EmbeddingModel +where + T: Default, +{ + pub fn new(client: Client, model: &str, ndims: usize) -> Self { Self { client, model: model.to_string(), diff --git a/rig-core/src/providers/together/streaming.rs b/rig-core/src/providers/together/streaming.rs index f4bf9bd5b..08a54a1f2 100644 --- a/rig-core/src/providers/together/streaming.rs +++ b/rig-core/src/providers/together/streaming.rs @@ -9,7 +9,7 @@ use crate::{ json_utils::merge, }; -impl CompletionModel { +impl CompletionModel { pub(crate) 
async fn stream( &self, completion_request: CompletionRequest, @@ -19,7 +19,10 @@ impl CompletionModel { request = merge(request, json!({"stream_tokens": true})); - let builder = self.client.post("/v1/chat/completions").json(&request); + let builder = self + .client + .reqwest_post("/v1/chat/completions") + .json(&request); send_compatible_streaming_request(builder).await } diff --git a/rig-core/src/providers/xai/client.rs b/rig-core/src/providers/xai/client.rs index e9eb0d4b4..ba970eed6 100644 --- a/rig-core/src/providers/xai/client.rs +++ b/rig-core/src/providers/xai/client.rs @@ -1,7 +1,10 @@ use super::completion::CompletionModel; -use crate::client::{ - ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError, - impl_conversion_traits, +use crate::{ + client::{ + ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError, + impl_conversion_traits, + }, + http_client, }; // ================================================================ @@ -9,62 +12,67 @@ use crate::client::{ // ================================================================ const XAI_BASE_URL: &str = "https://api.x.ai"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { api_key: &'a str, base_url: &'a str, - http_client: Option, + http_client: T, } -impl<'a> ClientBuilder<'a> { +impl<'a, T> ClientBuilder<'a, T> +where + T: Default, +{ pub fn new(api_key: &'a str) -> Self { Self { api_key, base_url: XAI_BASE_URL, - http_client: None, + http_client: Default::default(), } } +} +impl<'a, T> ClientBuilder<'a, T> { pub fn base_url(mut self, base_url: &'a str) -> Self { self.base_url = base_url; self } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); - self + pub fn with_client(self, http_client: U) -> ClientBuilder<'a, U> { + ClientBuilder { + api_key: self.api_key, + base_url: self.base_url, + http_client, + } } - pub fn build(self) -> Result { + pub fn build(self) -> Client { let mut 
default_headers = reqwest::header::HeaderMap::new(); default_headers.insert( reqwest::header::CONTENT_TYPE, "application/json".parse().unwrap(), ); - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? - }; - - Ok(Client { + Client { base_url: self.base_url.to_string(), api_key: self.api_key.to_string(), default_headers, - http_client, - }) + http_client: self.http_client, + } } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, - default_headers: reqwest::header::HeaderMap, - http_client: reqwest::Client, + default_headers: http_client::HeaderMap, + http_client: T, } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client +where + T: std::fmt::Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -75,7 +83,10 @@ impl std::fmt::Debug for Client { } } -impl Client { +impl Client +where + T: Default, +{ /// Create a new xAI client builder. /// /// # Example @@ -86,7 +97,7 @@ impl Client { /// let xai = Client::builder("your-xai-api-key") /// .build() /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_> { + pub fn builder(api_key: &str) -> ClientBuilder<'_, T> { ClientBuilder::new(api_key) } @@ -95,25 +106,63 @@ impl Client { /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 
pub fn new(api_key: &str) -> Self { - Self::builder(api_key) - .build() - .expect("xAI client should build") + Self::builder(api_key).build() } +} - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { +impl Client { + fn req( + &self, + method: http_client::Method, + url: &str, + ) -> http_client::Result { + let mut request = http_client::Builder::new().method(method).uri(url); + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", self.api_key)) + .map_err(|e| http_client::Error::Protocol(e.into()))?; + + if let Some(hs) = request.headers_mut() { + *hs = self.default_headers.clone(); + hs.insert("Authorization", auth_header); + } + + Ok(request) + } + + pub(crate) fn post(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); tracing::debug!("POST {}", url); + + self.req(http_client::Method::POST, &url) + } + + pub(crate) fn get(&self, path: &str) -> http_client::Result { + let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + + tracing::debug!("GET {}", url); + + self.req(http_client::Method::GET, &url) + } +} + +impl Client { + pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { + let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + + tracing::debug!("POST {}", url); + self.http_client .post(url) .bearer_auth(&self.api_key) .headers(self.default_headers.clone()) } - pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder { + pub(crate) fn reqwest_get(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - tracing::debug!("GET {}", url); + tracing::debug!("GET {}", url); + self.http_client .get(url) .bearer_auth(&self.api_key) @@ -121,7 +170,7 @@ impl Client { } } -impl ProviderClient for Client { +impl ProviderClient for Client { /// Create a new xAI client from the `XAI_API_KEY` environment variable. 
/// Panics if the environment variable is not set. fn from_env() -> Self { @@ -137,29 +186,38 @@ impl ProviderClient for Client { } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; /// Create a completion model with the given name. - fn completion_model(&self, model: &str) -> CompletionModel { + fn completion_model(&self, model: &str) -> CompletionModel { CompletionModel::new(self.clone(), model) } } -impl VerifyClient for Client { +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let response = self.get("/v1/api-key").send().await?; + let response = self + .reqwest_get("/v1/api-key") + .send() + .await + .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?; + match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => { Err(VerifyError::InvalidAuthentication) } reqwest::StatusCode::INTERNAL_SERVER_ERROR => { - Err(VerifyError::ProviderError(response.text().await?)) + Err(VerifyError::ProviderError(response.text().await.map_err( + |e| VerifyError::HttpError(http_client::Error::Instance(e.into())), + )?)) } _ => { - response.error_for_status()?; + response + .error_for_status() + .map_err(|e| VerifyError::HttpError(http_client::Error::Instance(e.into())))?; Ok(()) } } @@ -170,7 +228,7 @@ impl_conversion_traits!( AsEmbeddings, AsTranscription, AsImageGeneration, - AsAudioGeneration for Client + AsAudioGeneration for Client ); pub mod xai_api_types { diff --git a/rig-core/src/providers/xai/completion.rs b/rig-core/src/providers/xai/completion.rs index d9b296995..be9922165 100644 --- a/rig-core/src/providers/xai/completion.rs +++ b/rig-core/src/providers/xai/completion.rs @@ -5,7 +5,7 @@ use crate::{ completion::{self, CompletionError}, - json_utils, + http_client, json_utils, 
providers::openai::Message, }; @@ -31,12 +31,12 @@ pub const GROK_4: &str = "grok-4-0709"; // ================================================================= #[derive(Clone)] -pub struct CompletionModel { - pub(crate) client: Client, +pub struct CompletionModel { + pub(crate) client: Client, pub model: String, } -impl CompletionModel { +impl CompletionModel { pub(crate) fn create_completion_request( &self, completion_request: completion::CompletionRequest, @@ -95,7 +95,8 @@ impl CompletionModel { Ok(request) } - pub fn new(client: Client, model: &str) -> Self { + + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -103,7 +104,7 @@ impl CompletionModel { } } -impl completion::CompletionModel for CompletionModel { +impl completion::CompletionModel for CompletionModel { type Response = CompletionResponse; type StreamingResponse = openai::StreamingCompletionResponse; @@ -116,18 +117,28 @@ impl completion::CompletionModel for CompletionModel { let response = self .client - .post("/v1/chat/completions") + .reqwest_post("/v1/chat/completions") .json(&request) .send() - .await?; + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; if response.status().is_success() { - match response.json::>().await? 
{ + let body = response + .json::>() + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; + + match body { ApiResponse::Ok(completion) => completion.try_into(), ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message())), } } else { - Err(CompletionError::ProviderError(response.text().await?)) + Err(CompletionError::ProviderError( + response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } } diff --git a/rig-core/src/providers/xai/streaming.rs b/rig-core/src/providers/xai/streaming.rs index cf353f01a..8ba32cf17 100644 --- a/rig-core/src/providers/xai/streaming.rs +++ b/rig-core/src/providers/xai/streaming.rs @@ -6,7 +6,7 @@ use crate::providers::xai::completion::CompletionModel; use crate::streaming::StreamingCompletionResponse; use serde_json::json; -impl CompletionModel { +impl CompletionModel { pub(crate) async fn stream( &self, completion_request: CompletionRequest, @@ -16,7 +16,10 @@ impl CompletionModel { request = merge(request, json!({"stream": true})); - let builder = self.client.post("/v1/chat/completions").json(&request); + let builder = self + .client + .reqwest_post("/v1/chat/completions") + .json(&request); send_compatible_streaming_request(builder).await } From ad821acc00859edac45d9dd5b0a84d68b40ba3e9 Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Thu, 2 Oct 2025 18:04:31 -0400 Subject: [PATCH 05/20] ALL CLIENTS DONE ! 
--- rig-core/src/providers/moonshot.rs | 150 ++++++++++++------- rig-core/src/providers/ollama.rs | 207 +++++++++++++++++---------- rig-core/src/providers/perplexity.rs | 115 +++++++++------ rig-core/src/providers/voyageai.rs | 114 +++++++++------ 4 files changed, 376 insertions(+), 210 deletions(-) diff --git a/rig-core/src/providers/moonshot.rs b/rig-core/src/providers/moonshot.rs index c1f454d06..d6128ca14 100644 --- a/rig-core/src/providers/moonshot.rs +++ b/rig-core/src/providers/moonshot.rs @@ -11,6 +11,7 @@ use crate::client::{ ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError, }; +use crate::http_client::HttpClientExt; use crate::json_utils::merge; use crate::providers::openai::send_compatible_streaming_request; use crate::streaming::StreamingCompletionResponse; @@ -19,7 +20,7 @@ use crate::{ json_utils, providers::openai, }; -use crate::{impl_conversion_traits, message}; +use crate::{http_client, impl_conversion_traits, message}; use serde::Deserialize; use serde_json::{Value, json}; @@ -28,54 +29,59 @@ use serde_json::{Value, json}; // ================================================================ const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { api_key: &'a str, base_url: &'a str, - http_client: Option, + http_client: T, } -impl<'a> ClientBuilder<'a> { +impl<'a, T> ClientBuilder<'a, T> +where + T: Default, +{ pub fn new(api_key: &'a str) -> Self { Self { api_key, base_url: MOONSHOT_API_BASE_URL, - http_client: None, + http_client: Default::default(), } } +} +impl<'a, T> ClientBuilder<'a, T> { pub fn base_url(mut self, base_url: &'a str) -> Self { self.base_url = base_url; self } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); - self + pub fn with_client(self, http_client: U) -> ClientBuilder<'a, U> { + ClientBuilder { + api_key: self.api_key, + base_url: self.base_url, + http_client, + 
} } - pub fn build(self) -> Result { - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? - }; - - Ok(Client { + pub fn build(self) -> Client { + Client { base_url: self.base_url.to_string(), api_key: self.api_key.to_string(), - http_client, - }) + http_client: self.http_client, + } } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, - http_client: reqwest::Client, + http_client: T, } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client +where + T: std::fmt::Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -85,7 +91,10 @@ impl std::fmt::Debug for Client { } } -impl Client { +impl Client +where + T: Default, +{ /// Create a new Moonshot client builder. /// /// # Example @@ -96,7 +105,7 @@ impl Client { /// let moonshot = Client::builder("your-moonshot-api-key") /// .build() /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_> { + pub fn builder(api_key: &str) -> ClientBuilder<'_, T> { ClientBuilder::new(api_key) } @@ -105,23 +114,47 @@ impl Client { /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 
pub fn new(api_key: &str) -> Self { - Self::builder(api_key) - .build() - .expect("Moonshot client should build") + Self::builder(api_key).build() } +} - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { +impl Client +where + T: HttpClientExt, +{ + fn req( + &self, + method: http_client::Method, + path: &str, + ) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.post(url).bearer_auth(&self.api_key) + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(http::Error::from)?; + + Ok(http_client::Builder::new() + .method(method) + .uri(url) + .header("Authorization", auth_header)) + } + pub(crate) fn post(&self, path: &str) -> http_client::Result { + self.req(http_client::Method::POST, path) } - pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder { + pub(crate) fn get(&self, path: &str) -> http_client::Result { + self.req(http_client::Method::GET, path) + } +} + +impl Client { + fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - self.http_client.get(url).bearer_auth(&self.api_key) + + self.http_client.post(url).bearer_auth(&self.api_key) } } -impl ProviderClient for Client { +impl ProviderClient for Client { /// Create a new Moonshot client from the `MOONSHOT_API_KEY` environment variable. /// Panics if the environment variable is not set. fn from_env() -> Self { @@ -137,8 +170,8 @@ impl ProviderClient for Client { } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; /// Create a completion model with the given name. 
/// @@ -151,23 +184,32 @@ impl CompletionClient for Client { /// /// let completion_model = moonshot.completion_model(moonshot::MOONSHOT_CHAT); /// ``` - fn completion_model(&self, model: &str) -> CompletionModel { + fn completion_model(&self, model: &str) -> CompletionModel { CompletionModel::new(self.clone(), model) } } -impl VerifyClient for Client { +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let response = self.get("/models").send().await?; + let req = self + .get("/models")? + .body(http_client::NoBody) + .map_err(http_client::Error::from)?; + + let response = HttpClientExt::request(&self.http_client, req).await?; + match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication), - reqwest::StatusCode::INTERNAL_SERVER_ERROR => { - Err(VerifyError::ProviderError(response.text().await?)) + reqwest::StatusCode::INTERNAL_SERVER_ERROR + | reqwest::StatusCode::SERVICE_UNAVAILABLE + | reqwest::StatusCode::BAD_GATEWAY => { + let text = http_client::text(response).await?; + Err(VerifyError::ProviderError(text)) } _ => { - response.error_for_status()?; + //response.error_for_status()?; Ok(()) } } @@ -178,7 +220,7 @@ impl_conversion_traits!( AsEmbeddings, AsTranscription, AsImageGeneration, - AsAudioGeneration for Client + AsAudioGeneration for Client ); #[derive(Debug, Deserialize)] @@ -204,13 +246,13 @@ enum ApiResponse { pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k"; #[derive(Clone)] -pub struct CompletionModel { - client: Client, +pub struct CompletionModel { + client: Client, pub model: String, } -impl CompletionModel { - pub fn new(client: Client, model: &str) -> Self { +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -272,7 +314,7 @@ impl CompletionModel { } } -impl completion::CompletionModel for CompletionModel { +impl 
completion::CompletionModel for CompletionModel { type Response = openai::CompletionResponse; type StreamingResponse = openai::StreamingCompletionResponse; @@ -285,13 +327,17 @@ impl completion::CompletionModel for CompletionModel { let response = self .client - .post("/chat/completions") + .reqwest_post("/chat/completions") .json(&request) .send() - .await?; + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; if response.status().is_success() { - let t = response.text().await?; + let t = response + .text() + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; tracing::debug!(target: "rig", "MoonShot completion error: {}", t); match serde_json::from_str::>(&t)? { @@ -305,7 +351,11 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)), } } else { - Err(CompletionError::ProviderError(response.text().await?)) + Err(CompletionError::ProviderError( + response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } } @@ -321,7 +371,7 @@ impl completion::CompletionModel for CompletionModel { json!({"stream": true, "stream_options": {"include_usage": true}}), ); - let builder = self.client.post("/chat/completions").json(&request); + let builder = self.client.reqwest_post("/chat/completions").json(&request); send_compatible_streaming_request(builder).await } diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs index 6ec3b7079..e6613ffe8 100644 --- a/rig-core/src/providers/ollama.rs +++ b/rig-core/src/providers/ollama.rs @@ -43,6 +43,7 @@ use crate::client::{ VerifyError, }; use crate::completion::{GetTokenUsage, Usage}; +use crate::http_client::{self, HttpClientExt}; use crate::json_utils::merge_inplace; use crate::message::DocumentSourceKind; use crate::streaming::RawStreamingChoice; @@ -66,58 +67,64 @@ use url::Url; const 
OLLAMA_API_BASE_URL: &str = "http://localhost:11434"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { base_url: &'a str, - http_client: Option, + http_client: T, } -impl<'a> ClientBuilder<'a> { +impl<'a, T> ClientBuilder<'a, T> +where + T: Default, +{ #[allow(clippy::new_without_default)] pub fn new() -> Self { Self { base_url: OLLAMA_API_BASE_URL, - http_client: None, + http_client: Default::default(), } } +} +impl<'a, T> ClientBuilder<'a, T> { pub fn base_url(mut self, base_url: &'a str) -> Self { self.base_url = base_url; self } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); - self + pub fn with_client(self, http_client: U) -> ClientBuilder<'a, U> { + ClientBuilder { + base_url: self.base_url, + http_client, + } } - pub fn build(self) -> Result { - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? - }; - - Ok(Client { - base_url: Url::parse(self.base_url) - .map_err(|_| ClientBuilderError::InvalidProperty("base_url"))?, - http_client, - }) + pub fn build(self) -> Client { + Client { + base_url: self.base_url.into(), + http_client: self.http_client, + } } } #[derive(Clone, Debug)] -pub struct Client { - base_url: Url, - http_client: reqwest::Client, +pub struct Client { + base_url: String, + http_client: T, } -impl Default for Client { +impl Default for Client +where + T: Default, +{ fn default() -> Self { Self::new() } } -impl Client { +impl Client +where + T: Default, +{ /// Create a new Ollama client builder. /// /// # Example @@ -128,7 +135,7 @@ impl Client { /// let client = Client::builder() /// .build() /// ``` - pub fn builder() -> ClientBuilder<'static> { + pub fn builder<'a>() -> ClientBuilder<'a, T> { ClientBuilder::new() } @@ -137,27 +144,36 @@ impl Client { /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 
pub fn new() -> Self { - Self::builder().build().expect("Ollama client should build") + Self::builder().build() + } +} + +impl Client { + fn req(&self, method: http_client::Method, path: &str) -> http_client::Builder { + let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + http_client::Builder::new().method(method).uri(url) } - pub(crate) fn post(&self, path: &str) -> Result { - let url = self.base_url.join(path)?; - Ok(self.http_client.post(url)) + pub(crate) fn post(&self, path: &str) -> http_client::Builder { + self.req(http_client::Method::POST, path) } - pub(crate) fn get(&self, path: &str) -> Result { - let url = self.base_url.join(path)?; - Ok(self.http_client.get(url)) + pub(crate) fn get(&self, path: &str) -> http_client::Builder { + self.req(http_client::Method::GET, path) } } -impl ProviderClient for Client { - fn from_env() -> Self - where - Self: Sized, - { +impl Client { + fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { + let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + self.http_client.post(url) + } +} + +impl ProviderClient for Client { + fn from_env() -> Self { let api_base = std::env::var("OLLAMA_API_BASE_URL").expect("OLLAMA_API_BASE_URL not set"); - Self::builder().base_url(&api_base).build().unwrap() + Self::builder().base_url(&api_base).build() } fn from_val(input: crate::client::ProviderValue) -> Self { @@ -169,39 +185,52 @@ impl ProviderClient for Client { } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; - fn completion_model(&self, model: &str) -> CompletionModel { + fn completion_model(&self, model: &str) -> CompletionModel { CompletionModel::new(self.clone(), model) } } -impl EmbeddingsClient for Client { - type EmbeddingModel = EmbeddingModel; - fn embedding_model(&self, model: &str) -> EmbeddingModel { +impl EmbeddingsClient for Client { + type EmbeddingModel = 
EmbeddingModel; + fn embedding_model(&self, model: &str) -> EmbeddingModel { EmbeddingModel::new(self.clone(), model, 0) } - fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel { + fn embedding_model_with_ndims( + &self, + model: &str, + ndims: usize, + ) -> EmbeddingModel { EmbeddingModel::new(self.clone(), model, ndims) } - fn embeddings(&self, model: &str) -> EmbeddingsBuilder { + fn embeddings(&self, model: &str) -> EmbeddingsBuilder { EmbeddingsBuilder::new(self.embedding_model(model)) } } -impl VerifyClient for Client { +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let response = self + let req = self .get("api/tags") - .expect("Failed to build request") - .send() - .await?; + .body(http_client::NoBody) + .map_err(http_client::Error::from)?; + + let response = HttpClientExt::request(&self.http_client, req).await?; + match response.status() { reqwest::StatusCode::OK => Ok(()), + reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication), + reqwest::StatusCode::INTERNAL_SERVER_ERROR + | reqwest::StatusCode::SERVICE_UNAVAILABLE + | reqwest::StatusCode::BAD_GATEWAY => { + let text = http_client::text(response).await?; + Err(VerifyError::ProviderError(text)) + } _ => { - response.error_for_status()?; + //response.error_for_status()?; Ok(()) } } @@ -211,7 +240,7 @@ impl VerifyClient for Client { impl_conversion_traits!( AsTranscription, AsImageGeneration, - AsAudioGeneration for Client + AsAudioGeneration for Client ); // ---------- API Error and Response Structures ---------- @@ -263,14 +292,14 @@ impl From> for Result { + client: Client, pub model: String, ndims: usize, } -impl EmbeddingModel { - pub fn new(client: Client, model: &str, ndims: usize) -> Self { +impl EmbeddingModel { + pub fn new(client: Client, model: &str, ndims: usize) -> Self { Self { client, model: model.to_owned(), @@ -279,7 +308,7 @@ impl EmbeddingModel { } } 
-impl embeddings::EmbeddingModel for EmbeddingModel { +impl embeddings::EmbeddingModel for EmbeddingModel { const MAX_DOCUMENTS: usize = 1024; fn ndims(&self) -> usize { self.ndims @@ -290,17 +319,26 @@ impl embeddings::EmbeddingModel for EmbeddingModel { documents: impl IntoIterator, ) -> Result, EmbeddingError> { let docs: Vec = documents.into_iter().collect(); - let payload = json!({ + + let body = serde_json::to_vec(&json!({ "model": self.model, - "input": docs, - }); - let response = self.client.post("api/embed")?.json(&payload).send().await?; + "input": docs + }))?; + + let req = self + .client + .post("api/embed") + .body(body) + .map_err(|e| EmbeddingError::HttpError(e.into()))?; + + let response = HttpClientExt::request(&self.client.http_client, req).await?; if !response.status().is_success() { - return Err(EmbeddingError::ProviderError(response.text().await?)); + let text = http_client::text(response).await?; + return Err(EmbeddingError::ProviderError(text)); } - let bytes = response.bytes().await?; + let bytes: Vec = response.into_body().await?; let api_resp: EmbeddingResponse = serde_json::from_slice(&bytes)?; @@ -416,13 +454,13 @@ impl TryFrom for completion::CompletionResponse { + client: Client, pub model: String, } -impl CompletionModel { - pub fn new(client: Client, model: &str) -> Self { +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_owned(), @@ -514,7 +552,7 @@ impl GetTokenUsage for StreamingCompletionResponse { } } -impl completion::CompletionModel for CompletionModel { +impl completion::CompletionModel for CompletionModel { type Response = CompletionResponse; type StreamingResponse = StreamingCompletionResponse; @@ -527,16 +565,24 @@ impl completion::CompletionModel for CompletionModel { let response = self .client - .post("api/chat")? 
+ .reqwest_post("api/chat") .json(&request_payload) .send() - .await?; + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; if !response.status().is_success() { - return Err(CompletionError::ProviderError(response.text().await?)); + return Err(CompletionError::ProviderError( + response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )); } - let bytes = response.bytes().await?; + let bytes = response + .bytes() + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; tracing::debug!(target: "rig", "Received response from Ollama: {}", String::from_utf8_lossy(&bytes)); @@ -558,20 +604,27 @@ impl completion::CompletionModel for CompletionModel { let response = self .client - .post("api/chat")? + .reqwest_post("api/chat") .json(&request_payload) .send() - .await?; + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; if !response.status().is_success() { - return Err(CompletionError::ProviderError(response.text().await?)); + return Err(CompletionError::ProviderError( + response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )); } let stream = Box::pin(try_stream! 
{ let mut byte_stream = response.bytes_stream(); while let Some(chunk) = byte_stream.next().await { - let bytes = chunk?; + let bytes = chunk.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?; for line in bytes.split(|&b| b == b'\n') { if line.is_empty() { diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs index f17914aff..dde78dbd1 100644 --- a/rig-core/src/providers/perplexity.rs +++ b/rig-core/src/providers/perplexity.rs @@ -12,6 +12,7 @@ use crate::{ OneOrMany, client::{VerifyClient, VerifyError}, completion::{self, CompletionError, MessageError, message}, + http_client::{self, HttpClientExt}, impl_conversion_traits, json_utils, }; @@ -29,54 +30,59 @@ use serde_json::{Value, json}; // ================================================================ const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { api_key: &'a str, base_url: &'a str, - http_client: Option, + http_client: T, } -impl<'a> ClientBuilder<'a> { +impl<'a, T> ClientBuilder<'a, T> +where + T: Default, +{ pub fn new(api_key: &'a str) -> Self { Self { api_key, base_url: PERPLEXITY_API_BASE_URL, - http_client: None, + http_client: Default::default(), } } +} +impl<'a, T> ClientBuilder<'a, T> { pub fn base_url(mut self, base_url: &'a str) -> Self { self.base_url = base_url; self } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); - self + pub fn with_client(self, http_client: U) -> ClientBuilder<'a, U> { + ClientBuilder { + api_key: self.api_key, + base_url: self.base_url, + http_client, + } } - pub fn build(self) -> Result { - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? 
- }; - - Ok(Client { + pub fn build(self) -> Client { + Client { base_url: self.base_url.to_string(), api_key: self.api_key.to_string(), - http_client, - }) + http_client: self.http_client, + } } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, - http_client: reqwest::Client, + http_client: T, } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client +where + T: std::fmt::Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -86,7 +92,10 @@ impl std::fmt::Debug for Client { } } -impl Client { +impl Client +where + T: Default, +{ /// Create a new Perplexity client builder. /// /// # Example @@ -97,7 +106,7 @@ impl Client { /// let perplexity = Client::builder("your-perplexity-api-key") /// .build() /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_> { + pub fn builder(api_key: &str) -> ClientBuilder<'_, T> { ClientBuilder::new(api_key) } @@ -106,18 +115,31 @@ impl Client { /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 
pub fn new(api_key: &str) -> Self { - Self::builder(api_key) - .build() - .expect("Perplexity client should build") + Self::builder(api_key).build() + } +} + +impl Client +where + T: HttpClientExt, +{ + pub(crate) fn post(&self, path: &str) -> http_client::Result { + let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(http::Error::from)?; + + Ok(http_client::Request::post(url).header("Authorization", auth_header)) } +} - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { +impl Client { + fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); self.http_client.post(url).bearer_auth(&self.api_key) } } -impl ProviderClient for Client { +impl ProviderClient for Client { /// Create a new Perplexity client from the `PERPLEXITY_API_KEY` environment variable. /// Panics if the environment variable is not set. 
fn from_env() -> Self { @@ -133,15 +155,15 @@ impl ProviderClient for Client { } } -impl CompletionClient for Client { - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; - fn completion_model(&self, model: &str) -> CompletionModel { + fn completion_model(&self, model: &str) -> CompletionModel { CompletionModel::new(self.clone(), model) } } -impl VerifyClient for Client { +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { // No API endpoint to verify the API key @@ -153,7 +175,7 @@ impl_conversion_traits!( AsTranscription, AsEmbeddings, AsImageGeneration, - AsAudioGeneration for Client + AsAudioGeneration for Client ); #[derive(Debug, Deserialize)] @@ -261,13 +283,13 @@ impl TryFrom for completion::CompletionResponse { + client: Client, pub model: String, } -impl CompletionModel { - pub fn new(client: Client, model: &str) -> Self { +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -381,7 +403,7 @@ impl From for message::Message { } } -impl completion::CompletionModel for CompletionModel { +impl completion::CompletionModel for CompletionModel { type Response = CompletionResponse; type StreamingResponse = openai::StreamingCompletionResponse; @@ -394,13 +416,18 @@ impl completion::CompletionModel for CompletionModel { let response = self .client - .post("/chat/completions") + .reqwest_post("/chat/completions") .json(&request) .send() - .await?; + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; if response.status().is_success() { - match response.json::>().await? { + match response + .json::>() + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))? 
+ { ApiResponse::Ok(completion) => { tracing::info!(target: "rig", "Perplexity completion token usage: {}", @@ -411,7 +438,11 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)), } } else { - Err(CompletionError::ProviderError(response.text().await?)) + Err(CompletionError::ProviderError( + response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } } @@ -424,7 +455,7 @@ impl completion::CompletionModel for CompletionModel { request = merge(request, json!({"stream": true})); - let builder = self.client.post("/chat/completions").json(&request); + let builder = self.client.reqwest_post("/chat/completions").json(&request); send_compatible_streaming_request(builder).await } diff --git a/rig-core/src/providers/voyageai.rs b/rig-core/src/providers/voyageai.rs index f20becd5e..dd5630913 100644 --- a/rig-core/src/providers/voyageai.rs +++ b/rig-core/src/providers/voyageai.rs @@ -2,7 +2,8 @@ use crate::client::{ ClientBuilderError, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError, }; use crate::embeddings::EmbeddingError; -use crate::{embeddings, impl_conversion_traits}; +use crate::http_client::HttpClientExt; +use crate::{embeddings, http_client, impl_conversion_traits}; use serde::Deserialize; use serde_json::json; @@ -11,54 +12,59 @@ use serde_json::json; // ================================================================ const OPENAI_API_BASE_URL: &str = "https://api.voyageai.com/v1"; -pub struct ClientBuilder<'a> { +pub struct ClientBuilder<'a, T> { api_key: &'a str, base_url: &'a str, - http_client: Option, + http_client: T, } -impl<'a> ClientBuilder<'a> { +impl<'a, T> ClientBuilder<'a, T> +where + T: Default, +{ pub fn new(api_key: &'a str) -> Self { Self { api_key, base_url: OPENAI_API_BASE_URL, - http_client: None, + http_client: Default::default(), } } +} +impl<'a, T> ClientBuilder<'a, T> { pub fn base_url(mut 
self, base_url: &'a str) -> Self { self.base_url = base_url; self } - pub fn custom_client(mut self, client: reqwest::Client) -> Self { - self.http_client = Some(client); - self + pub fn with_client(self, http_client: U) -> ClientBuilder<'a, U> { + ClientBuilder { + api_key: self.api_key, + base_url: self.base_url, + http_client, + } } - pub fn build(self) -> Result { - let http_client = if let Some(http_client) = self.http_client { - http_client - } else { - reqwest::Client::builder().build()? - }; - - Ok(Client { + pub fn build(self) -> Client { + Client { base_url: self.base_url.to_string(), api_key: self.api_key.to_string(), - http_client, - }) + http_client: self.http_client, + } } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, - http_client: reqwest::Client, + http_client: T, } -impl std::fmt::Debug for Client { +impl std::fmt::Debug for Client +where + T: std::fmt::Debug, +{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("base_url", &self.base_url) @@ -68,7 +74,10 @@ impl std::fmt::Debug for Client { } } -impl Client { +impl Client +where + T: Default, +{ /// Create a new Voyage AI client builder. /// /// # Example @@ -79,7 +88,7 @@ impl Client { /// let voyageai = Client::builder("your-voyageai-api-key") /// .build() /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_> { + pub fn builder(api_key: &str) -> ClientBuilder<'_, T> { ClientBuilder::new(api_key) } @@ -88,18 +97,32 @@ impl Client { /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 
pub fn new(api_key: &str) -> Self { - Self::builder(api_key) - .build() - .expect("Voyage AI client should build") + Self::builder(api_key).build() + } +} + +impl Client +where + T: HttpClientExt, +{ + pub(crate) fn post(&self, path: &str) -> http_client::Result { + let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + + let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) + .map_err(http::Error::from)?; + + Ok(http_client::Request::post(url).header("Authorization", auth_header)) } +} - pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { +impl Client { + pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); self.http_client.post(url).bearer_auth(&self.api_key) } } -impl VerifyClient for Client { +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { // No API endpoint to verify the API key @@ -111,10 +134,10 @@ impl_conversion_traits!( AsCompletion, AsTranscription, AsImageGeneration, - AsAudioGeneration for Client + AsAudioGeneration for Client ); -impl ProviderClient for Client { +impl ProviderClient for Client { /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable. /// Panics if the environment variable is not set. fn from_env() -> Self { @@ -132,8 +155,8 @@ impl ProviderClient for Client { /// Although the models have default embedding dimensions, there are additional alternatives for increasing and decreasing the dimensions to your requirements. 
/// See Voyage AI's documentation: -impl EmbeddingsClient for Client { - type EmbeddingModel = EmbeddingModel; +impl EmbeddingsClient for Client { + type EmbeddingModel = EmbeddingModel; fn embedding_model(&self, model: &str) -> Self::EmbeddingModel { let ndims = match model { VOYAGE_CODE_2 => 1536, @@ -149,8 +172,8 @@ impl EmbeddingsClient for Client { } } -impl EmbeddingModel { - pub fn new(client: Client, model: &str, ndims: usize) -> Self { +impl EmbeddingModel { + pub fn new(client: Client, model: &str, ndims: usize) -> Self { Self { client, model: model.to_string(), @@ -226,13 +249,13 @@ pub struct EmbeddingData { } #[derive(Clone)] -pub struct EmbeddingModel { - client: Client, +pub struct EmbeddingModel { + client: Client, pub model: String, ndims: usize, } -impl embeddings::EmbeddingModel for EmbeddingModel { +impl embeddings::EmbeddingModel for EmbeddingModel { const MAX_DOCUMENTS: usize = 1024; fn ndims(&self) -> usize { @@ -248,16 +271,21 @@ impl embeddings::EmbeddingModel for EmbeddingModel { let response = self .client - .post("/embeddings") + .reqwest_post("/embeddings") .json(&json!({ "model": self.model, "input": documents, })) .send() - .await?; + .await + .map_err(|e| EmbeddingError::HttpError(http_client::Error::Instance(e.into())))?; if response.status().is_success() { - match response.json::>().await? { + match response + .json::>() + .await + .map_err(|e| EmbeddingError::HttpError(http_client::Error::Instance(e.into())))? 
+ { ApiResponse::Ok(response) => { tracing::info!(target: "rig", "VoyageAI embedding token usage: {}", @@ -283,7 +311,11 @@ impl embeddings::EmbeddingModel for EmbeddingModel { ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), } } else { - Err(EmbeddingError::ProviderError(response.text().await?)) + Err(EmbeddingError::ProviderError( + response.text().await.map_err(|e| { + EmbeddingError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } } } From 78d100c19c707ab47dac4327e8cee2a71e2a9091 Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Fri, 3 Oct 2025 17:30:46 -0400 Subject: [PATCH 06/20] Cleanup --- rig-core/examples/agent_with_galadriel.rs | 2 +- rig-core/examples/pdf_agent.rs | 2 +- rig-core/examples/vector_search_ollama.rs | 2 +- rig-core/src/client/builder.rs | 72 +++++------ rig-core/src/client/mod.rs | 68 +++++------ rig-core/src/http_client.rs | 48 +++++--- rig-core/src/providers/anthropic/client.rs | 115 ++++++++---------- .../src/providers/anthropic/completion.rs | 8 +- rig-core/src/providers/anthropic/streaming.rs | 2 +- rig-core/src/providers/azure.rs | 54 ++++---- rig-core/src/providers/cohere/client.rs | 72 +++++------ rig-core/src/providers/cohere/completion.rs | 2 +- rig-core/src/providers/cohere/embeddings.rs | 7 +- rig-core/src/providers/cohere/streaming.rs | 2 +- rig-core/src/providers/deepseek.rs | 28 ++--- rig-core/src/providers/galadriel.rs | 19 ++- rig-core/src/providers/gemini/client.rs | 7 +- rig-core/src/providers/gemini/completion.rs | 7 +- rig-core/src/providers/gemini/embedding.rs | 2 +- .../src/providers/gemini/transcription.rs | 2 +- rig-core/src/providers/groq.rs | 19 +-- rig-core/src/providers/huggingface/client.rs | 26 ++-- .../src/providers/huggingface/completion.rs | 3 +- .../providers/huggingface/image_generation.rs | 2 +- .../providers/huggingface/transcription.rs | 4 +- rig-core/src/providers/hyperbolic.rs | 28 ++--- rig-core/src/providers/mira.rs | 20 +-- 
rig-core/src/providers/mistral/client.rs | 66 +++------- rig-core/src/providers/mistral/completion.rs | 2 +- rig-core/src/providers/mistral/embedding.rs | 2 +- rig-core/src/providers/moonshot.rs | 24 ++-- rig-core/src/providers/ollama.rs | 8 +- .../src/providers/openai/audio_generation.rs | 43 ++++--- rig-core/src/providers/openai/client.rs | 26 ++-- .../src/providers/openai/completion/mod.rs | 2 +- rig-core/src/providers/openai/embedding.rs | 5 +- .../src/providers/openai/image_generation.rs | 2 +- .../src/providers/openai/responses_api/mod.rs | 2 +- .../src/providers/openai/transcription.rs | 4 +- rig-core/src/providers/openrouter/client.rs | 20 +-- .../src/providers/openrouter/completion.rs | 2 +- rig-core/src/providers/perplexity.rs | 22 +--- rig-core/src/providers/together/client.rs | 23 ++-- rig-core/src/providers/together/completion.rs | 2 +- rig-core/src/providers/together/embedding.rs | 2 +- rig-core/src/providers/voyageai.rs | 23 +--- rig-core/src/providers/xai/client.rs | 47 +------ rig-core/src/providers/xai/completion.rs | 2 +- rig-eternalai/src/providers/eternalai.rs | 31 ++++- rig-lancedb/tests/integration_tests.rs | 3 +- rig-mongodb/tests/integration_tests.rs | 6 +- rig-neo4j/tests/integration_tests.rs | 3 +- 52 files changed, 418 insertions(+), 577 deletions(-) diff --git a/rig-core/examples/agent_with_galadriel.rs b/rig-core/examples/agent_with_galadriel.rs index e5a160bf2..cd8866957 100644 --- a/rig-core/examples/agent_with_galadriel.rs +++ b/rig-core/examples/agent_with_galadriel.rs @@ -12,7 +12,7 @@ async fn main() -> Result<(), anyhow::Error> { if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() { builder = builder.fine_tune_api_key(fine_tune_api_key); } - let client = builder.build().expect("Failed to build client"); + let client = builder.build(); // Create agent with a single context prompt let comedian_agent = client diff --git a/rig-core/examples/pdf_agent.rs b/rig-core/examples/pdf_agent.rs index b664e4999..ab309929f 100644 --- 
a/rig-core/examples/pdf_agent.rs +++ b/rig-core/examples/pdf_agent.rs @@ -55,7 +55,7 @@ async fn main() -> Result<()> { // Initialize Ollama client let client = openai::Client::builder("ollama") .base_url("http://localhost:11434/v1") - .build()?; + .build(); // Load PDFs using Rig's built-in PDF loader let documents_dir = std::env::current_dir()?.join("rig-core/examples/documents"); diff --git a/rig-core/examples/vector_search_ollama.rs b/rig-core/examples/vector_search_ollama.rs index 1c23eedd8..076674c6f 100644 --- a/rig-core/examples/vector_search_ollama.rs +++ b/rig-core/examples/vector_search_ollama.rs @@ -24,7 +24,7 @@ async fn main() -> Result<(), anyhow::Error> { // Create ollama client let client = providers::ollama::Client::builder() .base_url("http://localhost:11434") - .build()?; + .build(); let embedding_model = client.embedding_model("nomic-embed-text"); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) diff --git a/rig-core/src/client/builder.rs b/rig-core/src/client/builder.rs index 03b1336fa..f093d1caa 100644 --- a/rig-core/src/client/builder.rs +++ b/rig-core/src/client/builder.rs @@ -78,93 +78,93 @@ impl<'a> DynClientBuilder { .register_all(vec![ ClientFactory::new( DefaultProviders::ANTHROPIC, - anthropic::Client::from_env_boxed, - anthropic::Client::from_val_boxed, + anthropic::Client::::from_env_boxed, + anthropic::Client::::from_val_boxed, ), ClientFactory::new( DefaultProviders::COHERE, - cohere::Client::from_env_boxed, - cohere::Client::from_val_boxed, + cohere::Client::::from_env_boxed, + cohere::Client::::from_val_boxed, ), ClientFactory::new( DefaultProviders::GEMINI, - gemini::Client::from_env_boxed, - gemini::Client::from_val_boxed, + gemini::Client::::from_env_boxed, + gemini::Client::::from_val_boxed, ), ClientFactory::new( DefaultProviders::HUGGINGFACE, - huggingface::Client::from_env_boxed, - huggingface::Client::from_val_boxed, + huggingface::Client::::from_env_boxed, + huggingface::Client::::from_val_boxed, ), 
ClientFactory::new( DefaultProviders::OPENAI, - openai::Client::from_env_boxed, - openai::Client::from_val_boxed, + openai::Client::::from_env_boxed, + openai::Client::::from_val_boxed, ), ClientFactory::new( DefaultProviders::OPENROUTER, - openrouter::Client::from_env_boxed, - openrouter::Client::from_val_boxed, + openrouter::Client::::from_env_boxed, + openrouter::Client::::from_val_boxed, ), ClientFactory::new( DefaultProviders::TOGETHER, - together::Client::from_env_boxed, - together::Client::from_val_boxed, + together::Client::::from_env_boxed, + together::Client::::from_val_boxed, ), ClientFactory::new( DefaultProviders::XAI, - xai::Client::from_env_boxed, - xai::Client::from_val_boxed, + xai::Client::::from_env_boxed, + xai::Client::::from_val_boxed, ), ClientFactory::new( DefaultProviders::AZURE, - azure::Client::from_env_boxed, - azure::Client::from_val_boxed, + azure::Client::::from_env_boxed, + azure::Client::::from_val_boxed, ), ClientFactory::new( DefaultProviders::DEEPSEEK, - deepseek::Client::from_env_boxed, - deepseek::Client::from_val_boxed, + deepseek::Client::::from_env_boxed, + deepseek::Client::::from_val_boxed, ), ClientFactory::new( DefaultProviders::GALADRIEL, - galadriel::Client::from_env_boxed, - galadriel::Client::from_val_boxed, + galadriel::Client::::from_env_boxed, + galadriel::Client::::from_val_boxed, ), ClientFactory::new( DefaultProviders::GROQ, - groq::Client::from_env_boxed, - groq::Client::from_val_boxed, + groq::Client::::from_env_boxed, + groq::Client::::from_val_boxed, ), ClientFactory::new( DefaultProviders::HYPERBOLIC, - hyperbolic::Client::from_env_boxed, - hyperbolic::Client::from_val_boxed, + hyperbolic::Client::::from_env_boxed, + hyperbolic::Client::::from_val_boxed, ), ClientFactory::new( DefaultProviders::MOONSHOT, - moonshot::Client::from_env_boxed, - moonshot::Client::from_val_boxed, + moonshot::Client::::from_env_boxed, + moonshot::Client::::from_val_boxed, ), ClientFactory::new( DefaultProviders::MIRA, - 
mira::Client::from_env_boxed, - mira::Client::from_val_boxed, + mira::Client::::from_env_boxed, + mira::Client::::from_val_boxed, ), ClientFactory::new( DefaultProviders::MISTRAL, - mistral::Client::from_env_boxed, - mistral::Client::from_val_boxed, + mistral::Client::::from_env_boxed, + mistral::Client::::from_val_boxed, ), ClientFactory::new( DefaultProviders::OLLAMA, - ollama::Client::from_env_boxed, - ollama::Client::from_val_boxed, + ollama::Client::::from_env_boxed, + ollama::Client::::from_val_boxed, ), ClientFactory::new( DefaultProviders::PERPLEXITY, - perplexity::Client::from_env_boxed, - perplexity::Client::from_val_boxed, + perplexity::Client::::from_env_boxed, + perplexity::Client::::from_val_boxed, ), ]) } diff --git a/rig-core/src/client/mod.rs b/rig-core/src/client/mod.rs index 22a4f19ef..364433004 100644 --- a/rig-core/src/client/mod.rs +++ b/rig-core/src/client/mod.rs @@ -327,16 +327,16 @@ mod tests { vec![ ClientConfig { name: "Anthropic", - factory_env: Box::new(anthropic::Client::from_env_boxed), - factory_val: Box::new(anthropic::Client::from_val_boxed), + factory_env: Box::new(anthropic::Client::::from_env_boxed), + factory_val: Box::new(anthropic::Client::::from_val_boxed), env_variable: "ANTHROPIC_API_KEY", completion_model: Some(anthropic::CLAUDE_3_5_SONNET), ..Default::default() }, ClientConfig { name: "Cohere", - factory_env: Box::new(cohere::Client::from_env_boxed), - factory_val: Box::new(cohere::Client::from_val_boxed), + factory_env: Box::new(cohere::Client::::from_env_boxed), + factory_val: Box::new(cohere::Client::::from_val_boxed), env_variable: "COHERE_API_KEY", completion_model: Some(cohere::COMMAND_R), embeddings_model: Some(cohere::EMBED_ENGLISH_LIGHT_V2), @@ -344,8 +344,8 @@ mod tests { }, ClientConfig { name: "Gemini", - factory_env: Box::new(gemini::Client::from_env_boxed), - factory_val: Box::new(gemini::Client::from_val_boxed), + factory_env: Box::new(gemini::Client::::from_env_boxed), + factory_val: 
Box::new(gemini::Client::::from_val_boxed), env_variable: "GEMINI_API_KEY", completion_model: Some(gemini::completion::GEMINI_2_0_FLASH), embeddings_model: Some(gemini::embedding::EMBEDDING_001), @@ -354,8 +354,8 @@ mod tests { }, ClientConfig { name: "Huggingface", - factory_env: Box::new(huggingface::Client::from_env_boxed), - factory_val: Box::new(huggingface::Client::from_val_boxed), + factory_env: Box::new(huggingface::Client::::from_env_boxed), + factory_val: Box::new(huggingface::Client::::from_val_boxed), env_variable: "HUGGINGFACE_API_KEY", completion_model: Some(huggingface::PHI_4), transcription_model: Some(huggingface::WHISPER_SMALL), @@ -364,8 +364,8 @@ mod tests { }, ClientConfig { name: "OpenAI", - factory_env: Box::new(openai::Client::from_env_boxed), - factory_val: Box::new(openai::Client::from_val_boxed), + factory_env: Box::new(openai::Client::::from_env_boxed), + factory_val: Box::new(openai::Client::::from_val_boxed), env_variable: "OPENAI_API_KEY", completion_model: Some(openai::GPT_4O), embeddings_model: Some(openai::TEXT_EMBEDDING_ADA_002), @@ -375,16 +375,16 @@ mod tests { }, ClientConfig { name: "OpenRouter", - factory_env: Box::new(openrouter::Client::from_env_boxed), - factory_val: Box::new(openrouter::Client::from_val_boxed), + factory_env: Box::new(openrouter::Client::::from_env_boxed), + factory_val: Box::new(openrouter::Client::::from_val_boxed), env_variable: "OPENROUTER_API_KEY", completion_model: Some(openrouter::CLAUDE_3_7_SONNET), ..Default::default() }, ClientConfig { name: "Together", - factory_env: Box::new(together::Client::from_env_boxed), - factory_val: Box::new(together::Client::from_val_boxed), + factory_env: Box::new(together::Client::::from_env_boxed), + factory_val: Box::new(together::Client::::from_val_boxed), env_variable: "TOGETHER_API_KEY", completion_model: Some(together::ALPACA_7B), embeddings_model: Some(together::BERT_BASE_UNCASED), @@ -392,8 +392,8 @@ mod tests { }, ClientConfig { name: "XAI", - factory_env: 
Box::new(xai::Client::from_env_boxed), - factory_val: Box::new(xai::Client::from_val_boxed), + factory_env: Box::new(xai::Client::::from_env_boxed), + factory_val: Box::new(xai::Client::::from_val_boxed), env_variable: "XAI_API_KEY", completion_model: Some(xai::GROK_3_MINI), embeddings_model: None, @@ -401,8 +401,8 @@ mod tests { }, ClientConfig { name: "Azure", - factory_env: Box::new(azure::Client::from_env_boxed), - factory_val: Box::new(azure::Client::from_val_boxed), + factory_env: Box::new(azure::Client::::from_env_boxed), + factory_val: Box::new(azure::Client::::from_val_boxed), env_variable: "AZURE_API_KEY", completion_model: Some(azure::GPT_4O), embeddings_model: Some(azure::TEXT_EMBEDDING_ADA_002), @@ -412,24 +412,24 @@ mod tests { }, ClientConfig { name: "Deepseek", - factory_env: Box::new(deepseek::Client::from_env_boxed), - factory_val: Box::new(deepseek::Client::from_val_boxed), + factory_env: Box::new(deepseek::Client::::from_env_boxed), + factory_val: Box::new(deepseek::Client::::from_val_boxed), env_variable: "DEEPSEEK_API_KEY", completion_model: Some(deepseek::DEEPSEEK_CHAT), ..Default::default() }, ClientConfig { name: "Galadriel", - factory_env: Box::new(galadriel::Client::from_env_boxed), - factory_val: Box::new(galadriel::Client::from_val_boxed), + factory_env: Box::new(galadriel::Client::::from_env_boxed), + factory_val: Box::new(galadriel::Client::::from_val_boxed), env_variable: "GALADRIEL_API_KEY", completion_model: Some(galadriel::GPT_4O), ..Default::default() }, ClientConfig { name: "Groq", - factory_env: Box::new(groq::Client::from_env_boxed), - factory_val: Box::new(groq::Client::from_val_boxed), + factory_env: Box::new(groq::Client::::from_env_boxed), + factory_val: Box::new(groq::Client::::from_val_boxed), env_variable: "GROQ_API_KEY", completion_model: Some(groq::MIXTRAL_8X7B_32768), transcription_model: Some(groq::DISTIL_WHISPER_LARGE_V3), @@ -437,8 +437,8 @@ mod tests { }, ClientConfig { name: "Hyperbolic", - factory_env: 
Box::new(hyperbolic::Client::from_env_boxed), - factory_val: Box::new(hyperbolic::Client::from_val_boxed), + factory_env: Box::new(hyperbolic::Client::::from_env_boxed), + factory_val: Box::new(hyperbolic::Client::::from_val_boxed), env_variable: "HYPERBOLIC_API_KEY", completion_model: Some(hyperbolic::LLAMA_3_1_8B), image_generation_model: Some(hyperbolic::SD1_5), @@ -447,24 +447,24 @@ mod tests { }, ClientConfig { name: "Mira", - factory_env: Box::new(mira::Client::from_env_boxed), - factory_val: Box::new(mira::Client::from_val_boxed), + factory_env: Box::new(mira::Client::::from_env_boxed), + factory_val: Box::new(mira::Client::::from_val_boxed), env_variable: "MIRA_API_KEY", completion_model: Some("gpt-4o"), ..Default::default() }, ClientConfig { name: "Moonshot", - factory_env: Box::new(moonshot::Client::from_env_boxed), - factory_val: Box::new(moonshot::Client::from_val_boxed), + factory_env: Box::new(moonshot::Client::::from_env_boxed), + factory_val: Box::new(moonshot::Client::::from_val_boxed), env_variable: "MOONSHOT_API_KEY", completion_model: Some(moonshot::MOONSHOT_CHAT), ..Default::default() }, ClientConfig { name: "Ollama", - factory_env: Box::new(ollama::Client::from_env_boxed), - factory_val: Box::new(ollama::Client::from_val_boxed), + factory_env: Box::new(ollama::Client::::from_env_boxed), + factory_val: Box::new(ollama::Client::::from_val_boxed), env_variable: "OLLAMA_ENABLED", completion_model: Some("llama3.1:8b"), embeddings_model: Some(ollama::NOMIC_EMBED_TEXT), @@ -472,8 +472,8 @@ mod tests { }, ClientConfig { name: "Perplexity", - factory_env: Box::new(perplexity::Client::from_env_boxed), - factory_val: Box::new(perplexity::Client::from_val_boxed), + factory_env: Box::new(perplexity::Client::::from_env_boxed), + factory_val: Box::new(perplexity::Client::::from_val_boxed), env_variable: "PERPLEXITY_API_KEY", completion_model: Some(perplexity::SONAR), ..Default::default() diff --git a/rig-core/src/http_client.rs b/rig-core/src/http_client.rs 
index 479a3470b..6a5a033f9 100644 --- a/rig-core/src/http_client.rs +++ b/rig-core/src/http_client.rs @@ -44,13 +44,20 @@ pub async fn text(response: Response>>) -> Result { Ok(String::from(String::from_utf8_lossy(&text))) } +pub fn with_bearer_auth(req: Builder, auth: &str) -> Result { + let auth_header = + HeaderValue::from_str(&format!("Bearer {}", auth)).map_err(http::Error::from)?; + + Ok(req.header("Authorization", auth_header)) +} + pub trait HttpClientExt: Send + Sync { fn request( &self, req: Request, ) -> impl Future>>> + Send where - T: Into, + T: Into + Send, U: From + Send; fn request_streaming( @@ -60,31 +67,36 @@ pub trait HttpClientExt: Send + Sync { where T: Into; - async fn get(&self, uri: Uri) -> Result>> + fn get(&self, uri: Uri) -> impl Future>>> + Send where T: From + Send, { - let req = Request::builder() - .method(Method::GET) - .uri(uri) - .body(NoBody)?; + async { + let req = Request::builder() + .method(Method::GET) + .uri(uri) + .body(NoBody)?; - self.request(req).await + self.request(req).await + } } - async fn post(&self, uri: Uri, body: T) -> Result>> + fn post( + &self, + uri: Uri, + body: T, + ) -> impl Future>>> + Send where - U: TryInto, - >::Error: Into, - T: Into, - V: From + Send, + T: Into + Send, + R: From + Send, { - let req = Request::builder() - .method(Method::POST) - .uri(uri) - .body(body.into())?; - - self.request(req).await + async { + let req = Request::builder() + .method(Method::POST) + .uri(uri) + .body(body)?; + self.request(req).await + } } } diff --git a/rig-core/src/providers/anthropic/client.rs b/rig-core/src/providers/anthropic/client.rs index 4774587d6..e75c92147 100644 --- a/rig-core/src/providers/anthropic/client.rs +++ b/rig-core/src/providers/anthropic/client.rs @@ -1,6 +1,6 @@ //! 
Anthropic client api implementation use bytes::Bytes; -use http_client::{Method, Request, Uri}; +use http_client::{Method, Request}; use super::completion::{ANTHROPIC_VERSION_LATEST, CompletionModel}; use crate::{ @@ -16,7 +16,7 @@ use crate::{ // ================================================================ const ANTHROPIC_API_BASE_URL: &str = "https://api.anthropic.com"; -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { api_key: &'a str, base_url: &'a str, anthropic_version: &'a str, @@ -24,6 +24,18 @@ pub struct ClientBuilder<'a, T> { http_client: T, } +impl<'a> ClientBuilder<'a, reqwest::Client> { + pub fn new(api_key: &'a str) -> Self { + ClientBuilder { + api_key, + base_url: ANTHROPIC_API_BASE_URL, + anthropic_version: ANTHROPIC_VERSION_LATEST, + anthropic_betas: None, + http_client: Default::default(), + } + } +} + /// Create a new anthropic client using the builder /// /// # Example @@ -38,24 +50,24 @@ pub struct ClientBuilder<'a, T> { /// ``` impl<'a, T> ClientBuilder<'a, T> where - T: HttpClientExt + Default, + T: HttpClientExt, { - pub fn new(api_key: &'a str) -> Self { - ClientBuilder { + pub fn new_with_client(api_key: &'a str, http_client: T) -> Self { + Self { api_key, base_url: ANTHROPIC_API_BASE_URL, anthropic_version: ANTHROPIC_VERSION_LATEST, anthropic_betas: None, - http_client: Default::default(), + http_client, } } - pub fn with_client(api_key: &'a str, http_client: T) -> Self { - Self { - api_key, - base_url: ANTHROPIC_API_BASE_URL, - anthropic_version: ANTHROPIC_VERSION_LATEST, - anthropic_betas: None, + pub fn with_client(self, http_client: U) -> ClientBuilder<'a, U> { + ClientBuilder { + api_key: self.api_key, + base_url: self.base_url, + anthropic_version: self.anthropic_version, + anthropic_betas: self.anthropic_betas, http_client, } } @@ -109,7 +121,7 @@ where } #[derive(Clone)] -pub struct Client { +pub struct Client { /// The base URL base_url: String, /// The API key @@ -134,35 +146,16 @@ 
where } } -fn build_uri(path: &str) -> Result { - Uri::builder() - .scheme("https") - .authority("api.anthropic.com") - .path_and_query(path) - .build() -} - impl Client where T: HttpClientExt + Clone + Default, { - /// Create a new Anthropic client. For more control, use the `builder` method. - /// - /// # Panics - /// - If the API key or version cannot be parsed as a Json value from a String. - /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). - pub fn new(api_key: &str) -> Self { - ClientBuilder::new(api_key) - .build() - .expect("Anthropic client should build") - } - pub async fn send( &self, req: http_client::Request, ) -> Result>, http_client::Error> where - U: Into, + U: Into + Send, V: From + Send, { self.http_client.request(req).await @@ -199,10 +192,7 @@ where req } - pub(crate) fn get( - &self, - path: &str, - ) -> Result, http::Error> { + pub(crate) fn get(&self, path: &str) -> http_client::Builder { let uri = format!("{}/{}", self.base_url, path).replace("//", "/"); let mut headers = self.default_headers.clone(); @@ -217,14 +207,24 @@ where *hs = headers; } - req.body(http_client::NoBody) + req } } -impl ProviderClient for Client -where - T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static, -{ +impl Client { + /// Create a new Anthropic client. For more control, use the `builder` method. + /// + /// # Panics + /// - If the API key or version cannot be parsed as a Json value from a String. + /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). + pub fn new(api_key: &str) -> Self { + ClientBuilder::new(api_key) + .build() + .expect("Anthropic client should build") + } +} + +impl ProviderClient for Client { /// Create a new Anthropic client from the `ANTHROPIC_API_KEY` environment variable. /// Panics if the environment variable is not set. 
fn from_env() -> Self { @@ -242,30 +242,23 @@ where } } -impl CompletionClient for Client -where - T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static, -{ - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; - fn completion_model(&self, model: &str) -> CompletionModel { + fn completion_model(&self, model: &str) -> CompletionModel { CompletionModel::new(self.clone(), model) } } -impl VerifyClient for Client -where - T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static, -{ +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { - let response: http_client::Response>> = self - .http_client - .request( - self.get("/v1/models") - .map_err(|e| http_client::Error::Protocol(e))?, - ) - .await?; + let req = self + .get("/v1/models") + .body(http_client::NoBody) + .map_err(http_client::Error::from)?; + + let response = HttpClientExt::request(&self.http_client, req).await?; match response.status() { http::StatusCode::OK => Ok(()), @@ -273,11 +266,11 @@ where Err(VerifyError::InvalidAuthentication) } http::StatusCode::INTERNAL_SERVER_ERROR => { - let text = String::from_utf8_lossy(&response.into_body().await?).into(); + let text = http_client::text(response).await?; Err(VerifyError::ProviderError(text)) } status if status.as_u16() == 529 => { - let text = String::from_utf8_lossy(&response.into_body().await?).into(); + let text = http_client::text(response).await?; Err(VerifyError::ProviderError(text)) } _ => { diff --git a/rig-core/src/providers/anthropic/completion.rs b/rig-core/src/providers/anthropic/completion.rs index 2da946021..da68995d8 100644 --- a/rig-core/src/providers/anthropic/completion.rs +++ b/rig-core/src/providers/anthropic/completion.rs @@ -578,7 +578,7 @@ impl TryFrom for message::Message { } #[derive(Clone)] -pub struct CompletionModel { +pub struct CompletionModel { pub(crate) client: Client, pub 
model: String, pub default_max_tokens: Option, @@ -720,14 +720,14 @@ where .client .send::<_, Bytes>(req) .await - .map_err(|e| CompletionError::HttpError(e.into()))?; + .map_err(CompletionError::HttpError)?; if response.status().is_success() { match serde_json::from_slice::>( response .into_body() .await - .map_err(|e| CompletionError::HttpError(e.into()))? + .map_err(CompletionError::HttpError)? .to_vec() .as_slice(), )? { @@ -750,7 +750,7 @@ where &response .into_body() .await - .map_err(|e| CompletionError::HttpError(e.into()))?, + .map_err(CompletionError::HttpError)?, ) .into(); Err(CompletionError::ProviderError(text)) diff --git a/rig-core/src/providers/anthropic/streaming.rs b/rig-core/src/providers/anthropic/streaming.rs index 0cb090fd9..da9559da7 100644 --- a/rig-core/src/providers/anthropic/streaming.rs +++ b/rig-core/src/providers/anthropic/streaming.rs @@ -163,7 +163,7 @@ where .client .post("/v1/messages") .body(body) - .map_err(|e| http_client::Error::Protocol(e))?; + .map_err(http_client::Error::Protocol)?; let response: http_client::StreamingResponse = self.client.send_streaming(req).await?; diff --git a/rig-core/src/providers/azure.rs b/rig-core/src/providers/azure.rs index e3af36f1b..e648b3995 100644 --- a/rig-core/src/providers/azure.rs +++ b/rig-core/src/providers/azure.rs @@ -15,7 +15,6 @@ use crate::http_client::{self, HttpClientExt}; use crate::json_utils::merge; use crate::streaming::StreamingCompletionResponse; use crate::{ - client::ClientBuilderError, completion::{self, CompletionError, CompletionRequest}, embeddings::{self, EmbeddingError}, json_utils, @@ -33,7 +32,7 @@ use serde_json::json; const DEFAULT_API_VERSION: &str = "2024-10-21"; -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { auth: AzureOpenAIAuth, api_version: Option<&'a str>, azure_endpoint: &'a str, @@ -89,7 +88,7 @@ impl<'a, T> ClientBuilder<'a, T> { } #[derive(Clone)] -pub struct Client { +pub struct Client { api_version: String, 
azure_endpoint: String, auth: AzureOpenAIAuth, @@ -192,23 +191,12 @@ where self.post(url) } - #[cfg(feature = "audio")] - fn post_audio_generation(&self, deployment_id: &str) -> http_client::Builder { - let url = format!( - "{}/openai/deployments/{}/audio/speech?api-version={}", - self.azure_endpoint, deployment_id, self.api_version - ) - .replace("//", "/"); - - self.post(url) - } - async fn send( &self, req: http_client::Request, ) -> http_client::Result>> where - U: Into, + U: Into + Send, R: From + Send, { self.http_client.request(req).await @@ -222,6 +210,17 @@ impl Client { self.http_client.post(url).header(key, val) } + #[cfg(feature = "audio")] + fn post_audio_generation(&self, deployment_id: &str) -> reqwest::RequestBuilder { + let url = format!( + "{}/openai/deployments/{}/audio/speech?api-version={}", + self.azure_endpoint, deployment_id, self.api_version + ) + .replace("//", "/"); + + self.post_reqwest(url) + } + fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder { let url = format!( "{}/openai/deployments/{}/chat/completions?api-version={}", @@ -436,7 +435,7 @@ impl std::fmt::Display for Usage { } #[derive(Clone)] -pub struct EmbeddingModel { +pub struct EmbeddingModel { client: Client, pub model: String, ndims: usize, @@ -548,7 +547,7 @@ pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct"; pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k"; #[derive(Clone)] -pub struct CompletionModel { +pub struct CompletionModel { client: Client, /// Name of the model (e.g.: gpt-4o-mini) pub model: String, @@ -683,7 +682,7 @@ impl completion::CompletionModel for CompletionModel { // ================================================================ #[derive(Clone)] -pub struct TranscriptionModel { +pub struct TranscriptionModel { client: Client, /// Name of the model (e.g.: gpt-3.5-turbo-1106) pub model: String, @@ -779,7 +778,7 @@ mod image_generation { use serde_json::json; #[derive(Clone)] - pub struct 
ImageGenerationModel { + pub struct ImageGenerationModel { client: Client, pub model: String, } @@ -856,16 +855,16 @@ pub use audio_generation::*; #[cfg_attr(docsrs, doc(cfg(feature = "audio")))] mod audio_generation { use super::Client; - use crate::audio_generation; use crate::audio_generation::{ AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse, }; use crate::client::AudioGenerationClient; + use crate::{audio_generation, http_client}; use bytes::Bytes; use serde_json::json; #[derive(Clone)] - pub struct AudioGenerationModel { + pub struct AudioGenerationModel { client: Client, model: String, } @@ -890,17 +889,24 @@ mod audio_generation { .post_audio_generation("/audio/speech") .json(&request) .send() - .await?; + .await + .map_err(|e| { + AudioGenerationError::HttpError(http_client::Error::Instance(e.into())) + })?; if !response.status().is_success() { return Err(AudioGenerationError::ProviderError(format!( "{}: {}", response.status(), - response.text().await? + response.text().await.map_err(|e| { + AudioGenerationError::HttpError(http_client::Error::Instance(e.into())) + })? 
))); } - let bytes = response.bytes().await?; + let bytes = response.bytes().await.map_err(|e| { + AudioGenerationError::HttpError(http_client::Error::Instance(e.into())) + })?; Ok(AudioGenerationResponse { audio: bytes.to_vec(), diff --git a/rig-core/src/providers/cohere/client.rs b/rig-core/src/providers/cohere/client.rs index 892b7dd04..6526bbbf2 100644 --- a/rig-core/src/providers/cohere/client.rs +++ b/rig-core/src/providers/cohere/client.rs @@ -28,28 +28,32 @@ pub enum ApiResponse { // ================================================================ const COHERE_API_BASE_URL: &str = "https://api.cohere.ai"; -pub struct ClientBuilder<'a, T> -where - T: HttpClientExt, -{ +pub struct ClientBuilder<'a, T = reqwest::Client> { api_key: &'a str, base_url: &'a str, http_client: T, } -impl<'a, T> ClientBuilder<'a, T> -where - T: HttpClientExt, -{ +impl<'a> ClientBuilder<'a, reqwest::Client> { pub fn new(api_key: &'a str) -> ClientBuilder<'a, reqwest::Client> { ClientBuilder { api_key, base_url: COHERE_API_BASE_URL, - http_client: reqwest::Client::new(), + http_client: Default::default(), + } + } +} + +impl<'a, T> ClientBuilder<'a, T> { + pub fn new_with_client(api_key: &'a str, http_client: T) -> Self { + ClientBuilder { + api_key, + base_url: COHERE_API_BASE_URL, + http_client, } } - pub fn with_client(api_key: &'a str, http_client: T) -> Self { + pub fn with_client(api_key: &str, http_client: U) -> ClientBuilder<'_, U> { ClientBuilder { api_key, base_url: COHERE_API_BASE_URL, @@ -57,7 +61,7 @@ where } } - pub fn base_url(mut self, base_url: &'a str) -> Self { + pub fn base_url(mut self, base_url: &'a str) -> ClientBuilder<'a, T> { self.base_url = base_url; self } @@ -72,7 +76,7 @@ where } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, http_client: T, @@ -91,37 +95,35 @@ where } } -impl Client -where - T: HttpClientExt + Clone, -{ +impl Client { /// Create a new Cohere client. For more control, use the `builder` method. 
/// /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). - pub fn new(api_key: &str) -> Client { - ClientBuilder::with_client(api_key, reqwest::Client::new()).build() + pub fn new(api_key: &str) -> Self { + ClientBuilder::new(api_key).build() } +} - pub(crate) fn post(&self, path: &str) -> http_client::Result - where - U: From + Send, - { +impl Client +where + T: HttpClientExt + Clone, +{ + fn req( + &self, + method: http_client::Method, + path: &str, + ) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let auth_header = - http_client::HeaderValue::try_from(format!("Bearer {}", self.api_key.as_str())) - .map_err(http::Error::from)?; - Ok(http_client::Request::post(url).header("Authorization", auth_header)) + http_client::with_bearer_auth( + http_client::Builder::new().method(method).uri(url), + &self.api_key, + ) } - pub(crate) fn get(&self, path: &str) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let auth_header = - http_client::HeaderValue::try_from(format!("Bearer {}", self.api_key.as_str())) - .map_err(http::Error::from)?; - - Ok(http_client::Request::get(url).header("Authorization", auth_header)) + pub(crate) fn post(&self, path: &str) -> http_client::Result { + self.req(http_client::Method::POST, path) } pub(crate) async fn send( @@ -129,7 +131,7 @@ where req: http_client::Request, ) -> http_client::Result>> where - U: Into, + U: Into + Send, V: From + Send, { self.http_client.request(req).await @@ -170,7 +172,7 @@ where } impl Client { - pub(crate) async fn eventsource( + pub(crate) async fn eventsource( &self, req: reqwest::RequestBuilder, ) -> Result { diff --git a/rig-core/src/providers/cohere/completion.rs b/rig-core/src/providers/cohere/completion.rs index ac6bf1e32..a7d3d299e 100644 --- a/rig-core/src/providers/cohere/completion.rs +++ b/rig-core/src/providers/cohere/completion.rs @@ -456,7 +456,7 @@ impl 
TryFrom for message::Message { } #[derive(Clone)] -pub struct CompletionModel { +pub struct CompletionModel { pub(crate) client: Client, pub model: String, } diff --git a/rig-core/src/providers/cohere/embeddings.rs b/rig-core/src/providers/cohere/embeddings.rs index f8a60746a..b2a08b361 100644 --- a/rig-core/src/providers/cohere/embeddings.rs +++ b/rig-core/src/providers/cohere/embeddings.rs @@ -59,7 +59,7 @@ impl std::fmt::Display for BilledUnits { } #[derive(Clone)] -pub struct EmbeddingModel { +pub struct EmbeddingModel { client: Client, pub model: String, pub input_type: String, @@ -93,8 +93,7 @@ where let req = self .client - .post::>("/v1/embed") - .map_err(|e| EmbeddingError::HttpError(e.into()))? + .post("/v1/embed")? .body(body) .map_err(|e| EmbeddingError::HttpError(e.into()))?; @@ -102,7 +101,7 @@ where .client .send::<_, Vec>(req) .await - .map_err(|e| EmbeddingError::HttpError(e.into()))?; + .map_err(EmbeddingError::HttpError)?; if response.status().is_success() { let body: ApiResponse = diff --git a/rig-core/src/providers/cohere/streaming.rs b/rig-core/src/providers/cohere/streaming.rs index f6b6b6f13..f15a310c3 100644 --- a/rig-core/src/providers/cohere/streaming.rs +++ b/rig-core/src/providers/cohere/streaming.rs @@ -102,7 +102,7 @@ impl CompletionModel { let mut event_source = self .client - .eventsource::(req) + .eventsource(req) .await .map_err(|e| CompletionError::ProviderError(e.to_string()))?; diff --git a/rig-core/src/providers/deepseek.rs b/rig-core/src/providers/deepseek.rs index 7166f94d3..780273df4 100644 --- a/rig-core/src/providers/deepseek.rs +++ b/rig-core/src/providers/deepseek.rs @@ -15,9 +15,7 @@ use futures::StreamExt; use reqwest_eventsource::{Event, RequestBuilderExt}; use std::collections::HashMap; -use crate::client::{ - ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError, -}; +use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError}; use crate::completion::GetTokenUsage; use 
crate::http_client::{self, HttpClientExt}; use crate::json_utils::merge; @@ -27,7 +25,6 @@ use crate::{ completion::{self, CompletionError, CompletionRequest}, impl_conversion_traits, json_utils, message, }; -use reqwest::Client as HttpClient; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -38,7 +35,7 @@ use super::openai::StreamingToolCall; // ================================================================ const DEEPSEEK_API_BASE_URL: &str = "https://api.deepseek.com"; -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { api_key: &'a str, base_url: &'a str, http_client: T, @@ -81,7 +78,7 @@ impl<'a, T> ClientBuilder<'a, T> { } #[derive(Clone)] -pub struct Client { +pub struct Client { pub base_url: String, api_key: String, http_client: T, @@ -138,17 +135,10 @@ where ) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(http::Error::from)?; - - Ok(http_client::Request::builder() - .method(method) - .uri(url) - .header("Authorization", auth_header)) - } - - pub(crate) fn post(&self, path: &str) -> http_client::Result { - self.req(http_client::Method::POST, path) + http_client::with_bearer_auth( + http_client::Request::builder().method(method).uri(url), + &self.api_key, + ) } pub(crate) fn get(&self, path: &str) -> http_client::Result { @@ -160,7 +150,7 @@ where req: http_client::Request, ) -> http_client::Result>> where - U: Into, + U: Into + Send, R: From + Send, { self.http_client.request(req).await @@ -566,7 +556,7 @@ impl TryFrom for completion::CompletionResponse { +pub struct CompletionModel { pub client: Client, pub model: String, } diff --git a/rig-core/src/providers/galadriel.rs b/rig-core/src/providers/galadriel.rs index 472c8b46b..b2b910db4 100644 --- a/rig-core/src/providers/galadriel.rs +++ b/rig-core/src/providers/galadriel.rs @@ -11,9 +11,7 @@ //! 
let gpt4o = client.completion_model(galadriel::GPT_4O); //! ``` use super::openai; -use crate::client::{ - ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError, -}; +use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError}; use crate::http_client::{self, HttpClientExt}; use crate::json_utils::merge; use crate::message::MessageError; @@ -33,7 +31,7 @@ use serde_json::{Value, json}; // ================================================================ const GALADRIEL_API_BASE_URL: &str = "https://api.galadriel.com/v1/verified"; -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { api_key: &'a str, fine_tune_api_key: Option<&'a str>, base_url: &'a str, @@ -84,7 +82,7 @@ impl<'a, T> ClientBuilder<'a, T> { } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, fine_tune_api_key: Option, @@ -139,16 +137,13 @@ where pub(crate) fn post(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(http::Error::from)?; - - let mut req = http_client::Request::post(url).header("Authorization", auth_header); + let mut req = http_client::Request::post(url); if let Some(fine_tune_key) = self.fine_tune_api_key.clone() { req = req.header("Fine-Tune-Authorization", fine_tune_key); } - Ok(req) + http_client::with_bearer_auth(req, &self.api_key) } async fn send( @@ -156,7 +151,7 @@ where req: http_client::Request, ) -> http_client::Result>> where - U: Into, + U: Into + Send, R: From + Send, { self.http_client.request(req).await @@ -508,7 +503,7 @@ pub struct Function { } #[derive(Clone)] -pub struct CompletionModel { +pub struct CompletionModel { client: Client, /// Name of the model (e.g.: gpt-3.5-turbo-1106) pub model: String, diff --git a/rig-core/src/providers/gemini/client.rs 
b/rig-core/src/providers/gemini/client.rs index e6838e881..e521d58c7 100644 --- a/rig-core/src/providers/gemini/client.rs +++ b/rig-core/src/providers/gemini/client.rs @@ -5,7 +5,6 @@ use crate::client::{ ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, TranscriptionClient, VerifyClient, VerifyError, impl_conversion_traits, }; -use crate::completion; use crate::http_client::{self, HttpClientExt}; use crate::{ Embed, @@ -20,7 +19,7 @@ use std::fmt::Debug; // ================================================================ const GEMINI_API_BASE_URL: &str = "https://generativelanguage.googleapis.com"; -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { api_key: &'a str, base_url: &'a str, http_client: T, @@ -78,7 +77,7 @@ where } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, default_headers: reqwest::header::HeaderMap, @@ -179,7 +178,7 @@ where req: http_client::Request, ) -> http_client::Result>> where - U: Into, + U: Into + Send, R: From + Send, { self.http_client.request(req).await diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index 68a110ac3..70360f3b4 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -28,7 +28,6 @@ pub const GEMINI_1_5_PRO_8B: &str = "gemini-1.5-pro-8b"; pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro"; use self::gemini_api_types::Schema; -use crate::http_client::HttpClientExt; use crate::message::Reasoning; use crate::providers::gemini::completion::gemini_api_types::AdditionalParameters; use crate::providers::gemini::streaming::StreamingCompletionResponse; @@ -50,7 +49,7 @@ use super::Client; // ================================================================= #[derive(Clone)] -pub struct CompletionModel { +pub struct CompletionModel { pub(crate) client: Client, pub model: String, } @@ -93,7 +92,7 @@ impl 
completion::CompletionModel for CompletionModel { let response_body = response .into_body() .await - .map_err(|e| CompletionError::HttpError(e))?; + .map_err(CompletionError::HttpError)?; let body: GenerateContentResponse = serde_json::from_slice(&response_body)?; @@ -115,7 +114,7 @@ impl completion::CompletionModel for CompletionModel { &response .into_body() .await - .map_err(|e| CompletionError::HttpError(e.into()))?, + .map_err(CompletionError::HttpError)?, ) .into(); diff --git a/rig-core/src/providers/gemini/embedding.rs b/rig-core/src/providers/gemini/embedding.rs index f4da9de1a..1614fa6c4 100644 --- a/rig-core/src/providers/gemini/embedding.rs +++ b/rig-core/src/providers/gemini/embedding.rs @@ -17,7 +17,7 @@ pub const EMBEDDING_001: &str = "embedding-001"; /// `text-embedding-004` embedding model pub const EMBEDDING_004: &str = "text-embedding-004"; #[derive(Clone)] -pub struct EmbeddingModel { +pub struct EmbeddingModel { client: Client, model: String, ndims: Option, diff --git a/rig-core/src/providers/gemini/transcription.rs b/rig-core/src/providers/gemini/transcription.rs index 6b67f4f21..23a79ae08 100644 --- a/rig-core/src/providers/gemini/transcription.rs +++ b/rig-core/src/providers/gemini/transcription.rs @@ -22,7 +22,7 @@ const TRANSCRIPTION_PREAMBLE: &str = "Translate the provided audio exactly. 
Do not add additional information."; #[derive(Clone)] -pub struct TranscriptionModel { +pub struct TranscriptionModel { client: Client, /// Name of the model (e.g.: gemini-1.5-flash) pub model: String, diff --git a/rig-core/src/providers/groq.rs b/rig-core/src/providers/groq.rs index d481f0874..93de12759 100644 --- a/rig-core/src/providers/groq.rs +++ b/rig-core/src/providers/groq.rs @@ -39,7 +39,7 @@ use serde_json::{Value, json}; // ================================================================ const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1"; -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { api_key: &'a str, base_url: &'a str, http_client: T, @@ -82,7 +82,7 @@ impl<'a, T> ClientBuilder<'a, T> { } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, http_client: T, @@ -139,17 +139,10 @@ where ) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(http::Error::from)?; - - Ok(http_client::Builder::new() - .method(method) - .uri(url) - .header("Authorization", auth_header)) - } - - fn post(&self, path: &str) -> http_client::Result { - self.req(http_client::Method::POST, path) + http_client::with_bearer_auth( + http_client::Builder::new().method(method).uri(url), + &self.api_key, + ) } fn get(&self, path: &str) -> http_client::Result { diff --git a/rig-core/src/providers/huggingface/client.rs b/rig-core/src/providers/huggingface/client.rs index 5ad78f42d..304922f31 100644 --- a/rig-core/src/providers/huggingface/client.rs +++ b/rig-core/src/providers/huggingface/client.rs @@ -108,7 +108,7 @@ impl Display for SubProvider { } } -pub struct ClientBuilder { +pub struct ClientBuilder { api_key: String, base_url: String, sub_provider: SubProvider, @@ -172,7 +172,7 @@ impl ClientBuilder { } #[derive(Clone)] -pub struct Client { +pub 
struct Client { base_url: String, default_headers: reqwest::header::HeaderMap, api_key: String, @@ -180,7 +180,7 @@ pub struct Client { pub(crate) sub_provider: SubProvider, } -impl std::fmt::Debug for Client +impl Debug for Client where T: Debug, { @@ -225,10 +225,6 @@ where } impl Client { - pub(crate) fn client(&self) -> &reqwest::Client { - &self.http_client - } - pub(crate) fn post_reqwest(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); @@ -246,31 +242,25 @@ where pub(crate) fn post(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(|e| http_client::Error::Protocol(e.into()))?; - - let mut req = http_client::Request::post(url).header("Authorization", auth_header); + let mut req = http_client::Request::post(url); if let Some(hs) = req.headers_mut() { *hs = self.default_headers.clone(); } - Ok(req) + http_client::with_bearer_auth(req, &self.api_key) } pub(crate) fn get(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(|e| http_client::Error::Protocol(e.into()))?; - - let mut req = http_client::Request::get(url).header("Authorization", auth_header); + let mut req = http_client::Request::get(url); if let Some(hs) = req.headers_mut() { *hs = self.default_headers.clone(); } - Ok(req) + http_client::with_bearer_auth(req, &self.api_key) } pub(crate) async fn send( @@ -278,7 +268,7 @@ where req: http_client::Request, ) -> http_client::Result>> where - U: Into, + U: Into + Send, V: From + Send, { self.http_client.request(req).await diff --git a/rig-core/src/providers/huggingface/completion.rs b/rig-core/src/providers/huggingface/completion.rs index be39ea2cf..f94a5d818 100644 
--- a/rig-core/src/providers/huggingface/completion.rs +++ b/rig-core/src/providers/huggingface/completion.rs @@ -3,7 +3,6 @@ use serde_json::{Value, json}; use std::{convert::Infallible, str::FromStr}; use super::client::Client; -use crate::http_client::HttpClientExt; use crate::providers::openai::StreamingCompletionResponse; use crate::{ OneOrMany, @@ -497,7 +496,7 @@ impl TryFrom for completion::CompletionResponse { +pub struct CompletionModel { pub(crate) client: Client, /// Name of the model (e.g: google/gemma-2-2b-it) pub model: String, diff --git a/rig-core/src/providers/huggingface/image_generation.rs b/rig-core/src/providers/huggingface/image_generation.rs index 458338a08..755eddb66 100644 --- a/rig-core/src/providers/huggingface/image_generation.rs +++ b/rig-core/src/providers/huggingface/image_generation.rs @@ -27,7 +27,7 @@ impl TryFrom } #[derive(Clone)] -pub struct ImageGenerationModel { +pub struct ImageGenerationModel { client: Client, pub model: String, } diff --git a/rig-core/src/providers/huggingface/transcription.rs b/rig-core/src/providers/huggingface/transcription.rs index a776d5142..b2618280d 100644 --- a/rig-core/src/providers/huggingface/transcription.rs +++ b/rig-core/src/providers/huggingface/transcription.rs @@ -1,4 +1,4 @@ -use crate::http_client::{self, HttpClientExt}; +use crate::http_client::HttpClientExt; use crate::providers::huggingface::Client; use crate::providers::huggingface::completion::ApiResponse; use crate::transcription; @@ -31,7 +31,7 @@ impl TryFrom } #[derive(Clone)] -pub struct TranscriptionModel { +pub struct TranscriptionModel { client: Client, /// Name of the model (e.g.: gpt-3.5-turbo-1106) pub model: String, diff --git a/rig-core/src/providers/hyperbolic.rs b/rig-core/src/providers/hyperbolic.rs index 78f316bdb..bbf6cc2fd 100644 --- a/rig-core/src/providers/hyperbolic.rs +++ b/rig-core/src/providers/hyperbolic.rs @@ -10,9 +10,7 @@ //! 
``` use super::openai::{AssistantContent, send_compatible_streaming_request}; -use crate::client::{ - ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError, -}; +use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError}; use crate::http_client::{self, HttpClientExt}; use crate::json_utils::merge_inplace; use crate::message; @@ -34,7 +32,7 @@ use serde_json::{Value, json}; // ================================================================ const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz"; -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { api_key: &'a str, base_url: &'a str, http_client: T, @@ -77,7 +75,7 @@ impl<'a, T> ClientBuilder<'a, T> { } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, http_client: T, @@ -134,17 +132,10 @@ where ) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(http::Error::from)?; - - Ok(http_client::Builder::new() - .method(method) - .uri(url) - .header("Authorization", auth_header)) - } - - fn post(&self, path: &str) -> http_client::Result { - self.req(http_client::Method::POST, path) + http_client::with_bearer_auth( + http_client::Builder::new().method(method).uri(url), + &self.api_key, + ) } fn get(&self, path: &str) -> http_client::Result { @@ -661,6 +652,7 @@ mod audio_generation { use crate::audio_generation; use crate::audio_generation::{AudioGenerationError, AudioGenerationRequest}; use crate::client::AudioGenerationClient; + use crate::http_client; use base64::Engine; use base64::prelude::BASE64_STANDARD; use serde::Deserialize; @@ -673,7 +665,7 @@ mod audio_generation { } impl AudioGenerationModel { - pub(crate) fn new(client: Client, language: &str) -> AudioGenerationModel { + pub(crate) fn new(client: Client, language: &str) -> 
AudioGenerationModel { Self { client, language: language.to_string(), @@ -763,7 +755,7 @@ mod audio_generation { /// /// let tts = hyperbolic.audio_generation_model("EN"); /// ``` - fn audio_generation_model(&self, language: &str) -> AudioGenerationModel { + fn audio_generation_model(&self, language: &str) -> AudioGenerationModel { AudioGenerationModel::new(self.clone(), language) } } diff --git a/rig-core/src/providers/mira.rs b/rig-core/src/providers/mira.rs index ed75566dd..7ec38dd00 100644 --- a/rig-core/src/providers/mira.rs +++ b/rig-core/src/providers/mira.rs @@ -110,7 +110,7 @@ struct ModelInfo { id: String, } -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { api_key: &'a str, base_url: &'a str, http_client: T, @@ -166,7 +166,7 @@ impl<'a, T> ClientBuilder<'a, T> { #[derive(Clone)] /// Client for interacting with the Mira API -pub struct Client { +pub struct Client { base_url: String, http_client: T, api_key: String, @@ -222,7 +222,7 @@ where pub async fn list_models(&self) -> Result, MiraError> { let req = self.get("/v1/models").and_then(|req| { req.body(http_client::NoBody) - .map_err(|e| http_client::Error::Protocol(e.into())) + .map_err(http_client::Error::Protocol) })?; let response = self.http_client.request(req).await?; @@ -253,23 +253,13 @@ where ) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(http::Error::from)?; - - let mut req = http_client::Builder::new() - .method(method) - .uri(url) - .header("Authorization", auth_header); + let mut req = http_client::Builder::new().method(method).uri(url); if let Some(hs) = req.headers_mut() { *hs = self.headers.clone(); } - Ok(req) - } - - pub(crate) fn post(&self, path: &str) -> http_client::Result { - self.req(http_client::Method::POST, path) + http_client::with_bearer_auth(req, &self.api_key) } pub(crate) fn get(&self, 
path: &str) -> http_client::Result { diff --git a/rig-core/src/providers/mistral/client.rs b/rig-core/src/providers/mistral/client.rs index d88e47523..48f127b31 100644 --- a/rig-core/src/providers/mistral/client.rs +++ b/rig-core/src/providers/mistral/client.rs @@ -14,16 +14,13 @@ use std::fmt::Debug; const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai"; -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { api_key: &'a str, base_url: &'a str, http_client: T, } -impl<'a, T> ClientBuilder<'a, T> -where - T: Default, -{ +impl<'a> ClientBuilder<'a, reqwest::Client> { pub fn new(api_key: &'a str) -> Self { Self { api_key, @@ -57,7 +54,7 @@ impl<'a, T> ClientBuilder<'a, T> { } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, http_client: T, @@ -76,23 +73,15 @@ where } } -impl Client -where - T: Default, -{ +impl Client { /// Create a new Mistral client. For more control, use the `builder` method. /// /// # Panics /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). - pub fn new(api_key: &str) -> Client { + pub fn new(api_key: &str) -> Self { Self::builder(api_key).build() } -} -impl Client -where - T: Default, -{ /// Create a new Mistral client builder. 
/// /// # Example @@ -103,7 +92,7 @@ where /// let mistral = Client::builder("your-mistral-api-key") /// .build() /// ``` - pub fn builder(api_key: &str) -> ClientBuilder<'_, T> { + pub fn builder(api_key: &str) -> ClientBuilder<'_> { ClientBuilder::new(api_key) } } @@ -115,19 +104,13 @@ where pub(crate) fn post(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(|e| http_client::Error::Protocol(e.into()))?; - - Ok(http_client::Request::post(url).header("Authorization", auth_header)) + http_client::with_bearer_auth(http_client::Request::post(url), &self.api_key) } pub(crate) fn get(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(|e| http_client::Error::Protocol(e.into()))?; - - Ok(http_client::Request::get(url).header("Authorization", auth_header)) + http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key) } pub(crate) async fn send( @@ -135,17 +118,14 @@ where req: http_client::Request, ) -> http_client::Result>> where - Body: Into, + Body: Into + Send, R: From + Send, { self.http_client.request(req).await } } -impl ProviderClient for Client -where - T: HttpClientExt + Debug + Default + Clone + 'static, -{ +impl ProviderClient for Client { /// Create a new Mistral client from the `MISTRAL_API_KEY` environment variable. /// Panics if the environment variable is not set. fn from_env() -> Self @@ -164,11 +144,8 @@ where } } -impl CompletionClient for Client -where - T: HttpClientExt + Debug + Default + Clone + 'static, -{ - type CompletionModel = CompletionModel; +impl CompletionClient for Client { + type CompletionModel = CompletionModel; /// Create a completion model with the given name. 
/// @@ -186,11 +163,8 @@ where } } -impl EmbeddingsClient for Client -where - T: HttpClientExt + Debug + Default + Clone + 'static, -{ - type EmbeddingModel = EmbeddingModel; +impl EmbeddingsClient for Client { + type EmbeddingModel = EmbeddingModel; /// Create an embedding model with the given name. /// Note: default embedding dimension of 0 will be used if model is not known. @@ -204,7 +178,7 @@ where /// /// let embedding_model = mistral.embedding_model(mistral::MISTRAL_EMBED); /// ``` - fn embedding_model(&self, model: &str) -> EmbeddingModel { + fn embedding_model(&self, model: &str) -> Self::EmbeddingModel { let ndims = match model { MISTRAL_EMBED => 1024, _ => 0, @@ -217,10 +191,7 @@ where } } -impl VerifyClient for Client -where - T: HttpClientExt + Debug + Default + Clone + 'static, -{ +impl VerifyClient for Client { #[cfg_attr(feature = "worker", worker::send)] async fn verify(&self) -> Result<(), VerifyError> { let req = self @@ -228,14 +199,13 @@ where .body(http_client::NoBody) .map_err(|e| VerifyError::HttpError(e.into()))?; - let response = self.http_client.request(req).await?; + let response = HttpClientExt::request(&self.http_client, req).await?; match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication), reqwest::StatusCode::INTERNAL_SERVER_ERROR => { - let text: Vec = response.into_body().await?; - let text = String::from_utf8_lossy(&text).into(); + let text = http_client::text(response).await?; Err(VerifyError::ProviderError(text)) } _ => { diff --git a/rig-core/src/providers/mistral/completion.rs b/rig-core/src/providers/mistral/completion.rs index 3eefacb5a..5a5e5820d 100644 --- a/rig-core/src/providers/mistral/completion.rs +++ b/rig-core/src/providers/mistral/completion.rs @@ -251,7 +251,7 @@ impl FromStr for AssistantContent { } #[derive(Clone)] -pub struct CompletionModel { +pub struct CompletionModel { pub(crate) client: Client, pub model: String, } diff --git 
a/rig-core/src/providers/mistral/embedding.rs b/rig-core/src/providers/mistral/embedding.rs index fa9557256..702819a43 100644 --- a/rig-core/src/providers/mistral/embedding.rs +++ b/rig-core/src/providers/mistral/embedding.rs @@ -16,7 +16,7 @@ pub const MISTRAL_EMBED: &str = "mistral-embed"; pub const MAX_DOCUMENTS: usize = 1024; #[derive(Clone)] -pub struct EmbeddingModel { +pub struct EmbeddingModel { client: Client, pub model: String, ndims: usize, diff --git a/rig-core/src/providers/moonshot.rs b/rig-core/src/providers/moonshot.rs index d6128ca14..b8c5d6196 100644 --- a/rig-core/src/providers/moonshot.rs +++ b/rig-core/src/providers/moonshot.rs @@ -8,9 +8,7 @@ //! //! let moonshot_model = client.completion_model(moonshot::MOONSHOT_CHAT); //! ``` -use crate::client::{ - ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError, -}; +use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError}; use crate::http_client::HttpClientExt; use crate::json_utils::merge; use crate::providers::openai::send_compatible_streaming_request; @@ -29,7 +27,7 @@ use serde_json::{Value, json}; // ================================================================ const MOONSHOT_API_BASE_URL: &str = "https://api.moonshot.cn/v1"; -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { api_key: &'a str, base_url: &'a str, http_client: T, @@ -72,7 +70,7 @@ impl<'a, T> ClientBuilder<'a, T> { } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, http_client: T, @@ -129,16 +127,10 @@ where ) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(http::Error::from)?; - - Ok(http_client::Builder::new() - .method(method) - .uri(url) - .header("Authorization", auth_header)) - } - pub(crate) fn post(&self, path: &str) -> http_client::Result { - 
self.req(http_client::Method::POST, path) + http_client::with_bearer_auth( + http_client::Builder::new().method(method).uri(url), + &self.api_key, + ) } pub(crate) fn get(&self, path: &str) -> http_client::Result { @@ -246,7 +238,7 @@ enum ApiResponse { pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k"; #[derive(Clone)] -pub struct CompletionModel { +pub struct CompletionModel { client: Client, pub model: String, } diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs index e6613ffe8..8921382dd 100644 --- a/rig-core/src/providers/ollama.rs +++ b/rig-core/src/providers/ollama.rs @@ -39,8 +39,7 @@ //! let extractor = client.extractor::("llama3.2"); //! ``` use crate::client::{ - ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, - VerifyError, + CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError, }; use crate::completion::{GetTokenUsage, Usage}; use crate::http_client::{self, HttpClientExt}; @@ -62,12 +61,11 @@ use reqwest; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use std::{convert::TryFrom, str::FromStr}; -use url::Url; // ---------- Main Client ---------- const OLLAMA_API_BASE_URL: &str = "http://localhost:11434"; -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { base_url: &'a str, http_client: T, } @@ -107,7 +105,7 @@ impl<'a, T> ClientBuilder<'a, T> { } #[derive(Clone, Debug)] -pub struct Client { +pub struct Client { base_url: String, http_client: T, } diff --git a/rig-core/src/providers/openai/audio_generation.rs b/rig-core/src/providers/openai/audio_generation.rs index d0360b1cf..470462a1d 100644 --- a/rig-core/src/providers/openai/audio_generation.rs +++ b/rig-core/src/providers/openai/audio_generation.rs @@ -1,21 +1,22 @@ use crate::audio_generation::{ self, AudioGenerationError, AudioGenerationRequest, AudioGenerationResponse, }; +use crate::http_client::{self, HttpClientExt}; use 
crate::providers::openai::Client; -use bytes::Bytes; +use bytes::{Buf, Bytes}; use serde_json::json; pub const TTS_1: &str = "tts-1"; pub const TTS_1_HD: &str = "tts-1-hd"; #[derive(Clone)] -pub struct AudioGenerationModel { - client: Client, +pub struct AudioGenerationModel { + client: Client, pub model: String, } -impl AudioGenerationModel { - pub fn new(client: Client, model: &str) -> Self { +impl AudioGenerationModel { + pub fn new(client: Client, model: &str) -> Self { Self { client, model: model.to_string(), @@ -23,7 +24,10 @@ impl AudioGenerationModel { } } -impl audio_generation::AudioGenerationModel for AudioGenerationModel { +impl audio_generation::AudioGenerationModel for AudioGenerationModel +where + T: HttpClientExt + Clone, +{ type Response = Bytes; #[cfg_attr(feature = "worker", worker::send)] @@ -31,29 +35,36 @@ impl audio_generation::AudioGenerationModel for AudioGenerationModel { &self, request: AudioGenerationRequest, ) -> Result, AudioGenerationError> { - let request = json!({ + let body = serde_json::to_vec(&json!({ "model": self.model, "input": request.text, "voice": request.voice, "speed": request.speed, - }); + }))?; - let response = self + let req = self .client - .post("/audio/speech") - .json(&request) - .send() - .await?; + .post("/audio/speech")? + .body(body) + .map_err(http_client::Error::from)?; + + let response = self.client.send(req).await?; if !response.status().is_success() { + let status = response.status(); + let mut bytes: Bytes = response.into_body().await?; + let mut as_slice = Vec::new(); + bytes.copy_to_slice(&mut as_slice); + + let text: String = String::from_utf8_lossy(&as_slice).into(); + return Err(AudioGenerationError::ProviderError(format!( "{}: {}", - response.status(), - response.text().await? 
+ status, text ))); } - let bytes = response.bytes().await?; + let bytes: Bytes = response.into_body().await?; Ok(AudioGenerationResponse { audio: bytes.to_vec(), diff --git a/rig-core/src/providers/openai/client.rs b/rig-core/src/providers/openai/client.rs index 2c254f249..33502a978 100644 --- a/rig-core/src/providers/openai/client.rs +++ b/rig-core/src/providers/openai/client.rs @@ -33,7 +33,7 @@ use serde::{Deserialize, Serialize}; // ================================================================ const OPENAI_API_BASE_URL: &str = "https://api.openai.com/v1"; -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { api_key: &'a str, base_url: &'a str, http_client: T, @@ -75,7 +75,7 @@ impl<'a, T> ClientBuilder<'a, T> { } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, http_client: T, @@ -126,18 +126,13 @@ where pub(crate) fn post(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(|e| http_client::Error::Protocol(e.into()))?; - Ok(http_client::Request::post(url).header("Authorization", auth_header)) + http_client::with_bearer_auth(http_client::Request::post(url), &self.api_key) } pub(crate) fn get(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(|e| http_client::Error::Protocol(e.into()))?; - - Ok(http_client::Request::get(url).header("Authorization", auth_header)) + http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key) } pub(crate) async fn send( @@ -145,7 +140,7 @@ where req: http_client::Request, ) -> http_client::Result>> where - U: Into, + U: Into + Send, R: From + Send, { self.http_client.request(req).await @@ -153,13 +148,6 @@ 
where } impl Client { - pub(crate) async fn send_reqwest( - &self, - req: reqwest::Request, - ) -> reqwest::Result { - self.http_client.execute(req).await - } - pub(crate) fn post_reqwest(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); @@ -278,8 +266,8 @@ impl ImageGenerationClient for Client { } #[cfg(feature = "audio")] -impl AudioGenerationClient for Client { - type AudioGenerationModel = AudioGenerationModel; +impl AudioGenerationClient for Client { + type AudioGenerationModel = AudioGenerationModel; /// Create an audio generation model with the given name. /// /// # Example diff --git a/rig-core/src/providers/openai/completion/mod.rs b/rig-core/src/providers/openai/completion/mod.rs index 570e72593..fa91da95a 100644 --- a/rig-core/src/providers/openai/completion/mod.rs +++ b/rig-core/src/providers/openai/completion/mod.rs @@ -712,7 +712,7 @@ impl fmt::Display for Usage { } #[derive(Clone)] -pub struct CompletionModel { +pub struct CompletionModel { pub(crate) client: Client, /// Name of the model (e.g.: gpt-3.5-turbo-1106) pub model: String, diff --git a/rig-core/src/providers/openai/embedding.rs b/rig-core/src/providers/openai/embedding.rs index 6458aaf59..2d20d8179 100644 --- a/rig-core/src/providers/openai/embedding.rs +++ b/rig-core/src/providers/openai/embedding.rs @@ -46,7 +46,7 @@ pub struct EmbeddingData { } #[derive(Clone)] -pub struct EmbeddingModel { +pub struct EmbeddingModel { client: Client, pub model: String, ndims: usize, @@ -78,8 +78,7 @@ where let req = self .client - .post("/embeddings") - .map_err(|e| EmbeddingError::HttpError(e.into()))? + .post("/embeddings")? 
.body(body) .map_err(|e| EmbeddingError::HttpError(e.into()))?; diff --git a/rig-core/src/providers/openai/image_generation.rs b/rig-core/src/providers/openai/image_generation.rs index aa2b2c9d8..9a8e8b041 100644 --- a/rig-core/src/providers/openai/image_generation.rs +++ b/rig-core/src/providers/openai/image_generation.rs @@ -47,7 +47,7 @@ impl TryFrom } #[derive(Clone)] -pub struct ImageGenerationModel { +pub struct ImageGenerationModel { client: Client, /// Name of the model (e.g.: dall-e-2) pub model: String, diff --git a/rig-core/src/providers/openai/responses_api/mod.rs b/rig-core/src/providers/openai/responses_api/mod.rs index ee0b7a500..fc60a986f 100644 --- a/rig-core/src/providers/openai/responses_api/mod.rs +++ b/rig-core/src/providers/openai/responses_api/mod.rs @@ -630,7 +630,7 @@ impl TryFrom<(String, crate::completion::CompletionRequest)> for CompletionReque /// The completion model struct for OpenAI's response API. #[derive(Clone)] -pub struct ResponsesCompletionModel { +pub struct ResponsesCompletionModel { /// The OpenAI client pub(crate) client: Client, /// Name of the model (e.g.: gpt-3.5-turbo-1106) diff --git a/rig-core/src/providers/openai/transcription.rs b/rig-core/src/providers/openai/transcription.rs index efa8615ad..b84f6f12a 100644 --- a/rig-core/src/providers/openai/transcription.rs +++ b/rig-core/src/providers/openai/transcription.rs @@ -1,4 +1,4 @@ -use crate::http_client::{self, HttpClientExt}; +use crate::http_client; use crate::providers::openai::{ApiResponse, Client}; use crate::transcription; use crate::transcription::TranscriptionError; @@ -29,7 +29,7 @@ impl TryFrom } #[derive(Clone)] -pub struct TranscriptionModel { +pub struct TranscriptionModel { client: Client, /// Name of the model (e.g.: gpt-3.5-turbo-1106) pub model: String, diff --git a/rig-core/src/providers/openrouter/client.rs b/rig-core/src/providers/openrouter/client.rs index 797de2ae5..9233939a1 100644 --- a/rig-core/src/providers/openrouter/client.rs +++ 
b/rig-core/src/providers/openrouter/client.rs @@ -1,5 +1,5 @@ use crate::{ - client::{ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError}, + client::{CompletionClient, ProviderClient, VerifyClient, VerifyError}, http_client::{self, HttpClientExt}, impl_conversion_traits, }; @@ -13,7 +13,7 @@ use super::completion::CompletionModel; // ================================================================ const OPENROUTER_API_BASE_URL: &str = "https://openrouter.ai/api/v1"; -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { api_key: &'a str, base_url: &'a str, http_client: T, @@ -56,7 +56,7 @@ impl<'a, T> ClientBuilder<'a, T> { } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, http_client: T, @@ -115,22 +115,10 @@ where } impl Client { - pub(crate) fn post(&self, path: &str) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(|e| http_client::Error::Protocol(e.into()))?; - - Ok(http_client::Request::post(url).header("Authorization", auth_header)) - } - pub(crate) fn get(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(|e| http_client::Error::Protocol(e.into()))?; - - Ok(http_client::Request::get(url).header("Authorization", auth_header)) + http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key) } } diff --git a/rig-core/src/providers/openrouter/completion.rs b/rig-core/src/providers/openrouter/completion.rs index 666771926..ecfcf83de 100644 --- a/rig-core/src/providers/openrouter/completion.rs +++ b/rig-core/src/providers/openrouter/completion.rs @@ -122,7 +122,7 @@ pub struct Choice { } #[derive(Clone)] -pub struct CompletionModel 
{ +pub struct CompletionModel { pub(crate) client: Client, /// Name of the model (e.g.: deepseek-ai/DeepSeek-R1) pub model: String, diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs index dde78dbd1..2e6d25bba 100644 --- a/rig-core/src/providers/perplexity.rs +++ b/rig-core/src/providers/perplexity.rs @@ -12,11 +12,10 @@ use crate::{ OneOrMany, client::{VerifyClient, VerifyError}, completion::{self, CompletionError, MessageError, message}, - http_client::{self, HttpClientExt}, - impl_conversion_traits, json_utils, + http_client, impl_conversion_traits, json_utils, }; -use crate::client::{ClientBuilderError, CompletionClient, ProviderClient}; +use crate::client::{CompletionClient, ProviderClient}; use crate::completion::CompletionRequest; use crate::json_utils::merge; use crate::providers::openai; @@ -30,7 +29,7 @@ use serde_json::{Value, json}; // ================================================================ const PERPLEXITY_API_BASE_URL: &str = "https://api.perplexity.ai"; -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { api_key: &'a str, base_url: &'a str, http_client: T, @@ -73,7 +72,7 @@ impl<'a, T> ClientBuilder<'a, T> { } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, http_client: T, @@ -119,19 +118,6 @@ where } } -impl Client -where - T: HttpClientExt, -{ - pub(crate) fn post(&self, path: &str) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(http::Error::from)?; - - Ok(http_client::Request::post(url).header("Authorization", auth_header)) - } -} - impl Client { fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); diff --git a/rig-core/src/providers/together/client.rs 
b/rig-core/src/providers/together/client.rs index 647359d5a..29e9b93f9 100644 --- a/rig-core/src/providers/together/client.rs +++ b/rig-core/src/providers/together/client.rs @@ -11,7 +11,7 @@ use rig::client::CompletionClient; // ================================================================ const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz"; -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { api_key: &'a str, base_url: &'a str, http_client: T, @@ -60,7 +60,7 @@ impl<'a, T> ClientBuilder<'a, T> { } } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, default_headers: reqwest::header::HeaderMap, api_key: String, @@ -117,15 +117,13 @@ where tracing::debug!("POST {}", url); - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(|e| http_client::Error::Protocol(e.into()))?; - let mut req = http_client::Request::post(url).header("Authorization", auth_header); + let mut req = http_client::Request::post(url); if let Some(hs) = req.headers_mut() { *hs = self.default_headers.clone(); } - Ok(req) + http_client::with_bearer_auth(req, &self.api_key) } pub(crate) fn get(&self, path: &str) -> http_client::Result { @@ -133,16 +131,13 @@ where tracing::debug!("GET {}", url); - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(|e| http_client::Error::Protocol(e.into()))?; - - let mut req = http_client::Request::get(url).header("Authorization", auth_header); + let mut req = http_client::Request::get(url); if let Some(hs) = req.headers_mut() { *hs = self.default_headers.clone(); } - Ok(req) + http_client::with_bearer_auth(req, &self.api_key) } pub(crate) async fn send( @@ -150,7 +145,7 @@ where req: http_client::Request, ) -> http_client::Result>> where - U: Into, + U: Into + Send, R: From + Send, { self.http_client.request(req).await @@ -158,10 +153,6 @@ where } impl Client { - pub(crate) fn 
reqwest_client(&self) -> &reqwest::Client { - &self.http_client - } - pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); diff --git a/rig-core/src/providers/together/completion.rs b/rig-core/src/providers/together/completion.rs index a07445d47..7c858fb15 100644 --- a/rig-core/src/providers/together/completion.rs +++ b/rig-core/src/providers/together/completion.rs @@ -128,7 +128,7 @@ pub const WIZARDLM_13B_V1_2: &str = "WizardLM/WizardLM-13B-V1.2"; // ================================================================= #[derive(Clone)] -pub struct CompletionModel { +pub struct CompletionModel { pub(crate) client: Client, pub model: String, } diff --git a/rig-core/src/providers/together/embedding.rs b/rig-core/src/providers/together/embedding.rs index 9efb8a4f9..c1c025c3e 100644 --- a/rig-core/src/providers/together/embedding.rs +++ b/rig-core/src/providers/together/embedding.rs @@ -66,7 +66,7 @@ pub struct Usage { } #[derive(Clone)] -pub struct EmbeddingModel { +pub struct EmbeddingModel { client: Client, pub model: String, ndims: usize, diff --git a/rig-core/src/providers/voyageai.rs b/rig-core/src/providers/voyageai.rs index dd5630913..9dd969016 100644 --- a/rig-core/src/providers/voyageai.rs +++ b/rig-core/src/providers/voyageai.rs @@ -1,8 +1,5 @@ -use crate::client::{ - ClientBuilderError, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError, -}; +use crate::client::{EmbeddingsClient, ProviderClient, VerifyClient, VerifyError}; use crate::embeddings::EmbeddingError; -use crate::http_client::HttpClientExt; use crate::{embeddings, http_client, impl_conversion_traits}; use serde::Deserialize; use serde_json::json; @@ -12,7 +9,7 @@ use serde_json::json; // ================================================================ const OPENAI_API_BASE_URL: &str = "https://api.voyageai.com/v1"; -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { 
api_key: &'a str, base_url: &'a str, http_client: T, @@ -55,7 +52,7 @@ impl<'a, T> ClientBuilder<'a, T> { } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, http_client: T, @@ -101,20 +98,6 @@ where } } -impl Client -where - T: HttpClientExt, -{ - pub(crate) fn post(&self, path: &str) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", &self.api_key)) - .map_err(http::Error::from)?; - - Ok(http_client::Request::post(url).header("Authorization", auth_header)) - } -} - impl Client { pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); diff --git a/rig-core/src/providers/xai/client.rs b/rig-core/src/providers/xai/client.rs index ba970eed6..b45a25112 100644 --- a/rig-core/src/providers/xai/client.rs +++ b/rig-core/src/providers/xai/client.rs @@ -1,9 +1,6 @@ use super::completion::CompletionModel; use crate::{ - client::{ - ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError, - impl_conversion_traits, - }, + client::{CompletionClient, ProviderClient, VerifyClient, VerifyError, impl_conversion_traits}, http_client, }; @@ -12,7 +9,7 @@ use crate::{ // ================================================================ const XAI_BASE_URL: &str = "https://api.x.ai"; -pub struct ClientBuilder<'a, T> { +pub struct ClientBuilder<'a, T = reqwest::Client> { api_key: &'a str, base_url: &'a str, http_client: T, @@ -62,7 +59,7 @@ impl<'a, T> ClientBuilder<'a, T> { } #[derive(Clone)] -pub struct Client { +pub struct Client { base_url: String, api_key: String, default_headers: http_client::HeaderMap, @@ -110,42 +107,6 @@ where } } -impl Client { - fn req( - &self, - method: http_client::Method, - url: &str, - ) -> http_client::Result { - let mut request = http_client::Builder::new().method(method).uri(url); 
- - let auth_header = http_client::HeaderValue::from_str(&format!("Bearer {}", self.api_key)) - .map_err(|e| http_client::Error::Protocol(e.into()))?; - - if let Some(hs) = request.headers_mut() { - *hs = self.default_headers.clone(); - hs.insert("Authorization", auth_header); - } - - Ok(request) - } - - pub(crate) fn post(&self, path: &str) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - - tracing::debug!("POST {}", url); - - self.req(http_client::Method::POST, &url) - } - - pub(crate) fn get(&self, path: &str) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - - tracing::debug!("GET {}", url); - - self.req(http_client::Method::GET, &url) - } -} - impl Client { pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); @@ -161,7 +122,7 @@ impl Client { pub(crate) fn reqwest_get(&self, path: &str) -> reqwest::RequestBuilder { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - tracing::debug!("POST {}", url); + tracing::debug!("GET {}", url); self.http_client .get(url) diff --git a/rig-core/src/providers/xai/completion.rs b/rig-core/src/providers/xai/completion.rs index be9922165..fe683be5d 100644 --- a/rig-core/src/providers/xai/completion.rs +++ b/rig-core/src/providers/xai/completion.rs @@ -31,7 +31,7 @@ pub const GROK_4: &str = "grok-4-0709"; // ================================================================= #[derive(Clone)] -pub struct CompletionModel { +pub struct CompletionModel { pub(crate) client: Client, pub model: String, } diff --git a/rig-eternalai/src/providers/eternalai.rs b/rig-eternalai/src/providers/eternalai.rs index cf30c80f8..3a95082a0 100644 --- a/rig-eternalai/src/providers/eternalai.rs +++ b/rig-eternalai/src/providers/eternalai.rs @@ -19,6 +19,7 @@ use rig::completion::GetTokenUsage; use rig::completion::{CompletionError, CompletionRequest}; use 
rig::embeddings::{EmbeddingError, EmbeddingsBuilder}; use rig::extractor::ExtractorBuilder; +use rig::http_client; use rig::message; use rig::message::AssistantContent; use rig::providers::openai::{self, Message}; @@ -317,10 +318,15 @@ impl embeddings::EmbeddingModel for EmbeddingModel { "input": documents, })) .send() - .await?; + .await + .map_err(|e| EmbeddingError::HttpError(http_client::Error::Instance(e.into())))?; if response.status().is_success() { - match response.json::>().await? { + match response + .json::>() + .await + .map_err(|e| EmbeddingError::HttpError(http_client::Error::Instance(e.into())))? + { ApiResponse::Ok(response) => { tracing::info!(target: "rig", "EternalAI embedding token usage: {}", @@ -346,7 +352,11 @@ impl embeddings::EmbeddingModel for EmbeddingModel { ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), } } else { - Err(EmbeddingError::ProviderError(response.text().await?)) + Err(EmbeddingError::ProviderError( + response.text().await.map_err(|e| { + EmbeddingError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } } } @@ -631,10 +641,15 @@ impl completion::CompletionModel for CompletionModel { }, ) .send() - .await?; + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; if response.status().is_success() { - match response.json::>().await? { + match response + .json::>() + .await + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))? 
+ { ApiResponse::Ok(response) => { tracing::info!(target: "rig", "EternalAI completion token usage: {:?}", @@ -654,7 +669,11 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } } else { - Err(CompletionError::ProviderError(response.text().await?)) + Err(CompletionError::ProviderError( + response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } } diff --git a/rig-lancedb/tests/integration_tests.rs b/rig-lancedb/tests/integration_tests.rs index 6d6d86a91..f0d9cf106 100644 --- a/rig-lancedb/tests/integration_tests.rs +++ b/rig-lancedb/tests/integration_tests.rs @@ -105,8 +105,7 @@ async fn vector_search_test() { // Initialize OpenAI client let openai_client = openai::Client::builder("TEST") .base_url(&server.base_url()) - .build() - .unwrap(); + .build(); // Select an embedding model. let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); diff --git a/rig-mongodb/tests/integration_tests.rs b/rig-mongodb/tests/integration_tests.rs index 21c4bf105..79ac73e0f 100644 --- a/rig-mongodb/tests/integration_tests.rs +++ b/rig-mongodb/tests/integration_tests.rs @@ -114,8 +114,7 @@ async fn vector_search_test() { // Initialize OpenAI client let openai_client = openai::Client::builder("TEST") .base_url(&server.base_url()) - .build() - .unwrap(); + .build(); // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); @@ -220,8 +219,7 @@ async fn insert_documents_test() { // Initialize OpenAI client let openai_client = openai::Client::builder("TEST") .base_url(&server.base_url()) - .build() - .unwrap(); + .build(); let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); // Setup MongoDB container diff --git a/rig-neo4j/tests/integration_tests.rs b/rig-neo4j/tests/integration_tests.rs index 64cb3c0c2..894fbe03d 100644 --- 
a/rig-neo4j/tests/integration_tests.rs +++ b/rig-neo4j/tests/integration_tests.rs @@ -123,8 +123,7 @@ async fn vector_search_test() { // Initialize OpenAI client let openai_client = openai::Client::builder("TEST") .base_url(&server.base_url()) - .build() - .unwrap(); + .build(); // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); From 874b731acd3f8b2360c706663fa9c4432ac968a2 Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Mon, 6 Oct 2025 16:57:41 -0400 Subject: [PATCH 07/20] Merged main --- CONTRIBUTING.md | 9 +- Cargo.lock | 362 +++++++++++---- Cargo.toml | 2 +- README.md | 11 +- rig-bedrock/CHANGELOG.md | 10 + rig-bedrock/Cargo.toml | 4 +- rig-bedrock/src/types/document.rs | 37 +- rig-core/CHANGELOG.md | 30 ++ rig-core/Cargo.toml | 9 +- rig-core/README.md | 22 +- rig-core/examples/agent_with_echochambers.rs | 4 +- rig-core/examples/agent_with_tools_otel.rs | 166 +++++++ rig-core/examples/calculator_chatbot.rs | 4 +- rig-core/examples/dyn_client_streaming.rs | 156 +++++++ rig-core/examples/multi_agent.rs | 4 +- .../openai_agent_completions_api_otel.rs | 60 +++ .../openai_streaming_with_tools_otel.rs | 154 +++++++ rig-core/examples/otel/Dockerfile | 6 + rig-core/examples/otel/config.yaml | 28 ++ rig-core/examples/pdf_agent.rs | 4 +- rig-core/src/agent/builder.rs | 20 + rig-core/src/agent/completion.rs | 15 +- rig-core/src/agent/prompt_request/mod.rs | 107 ++++- .../src/agent/prompt_request/streaming.rs | 422 +++++++++++------- rig-core/src/agent/tool.rs | 20 +- rig-core/src/cli_chatbot.rs | 331 +++++++------- rig-core/src/client/builder.rs | 170 +++++++ rig-core/src/client/mod.rs | 1 + rig-core/src/completion/message.rs | 80 +++- rig-core/src/completion/request.rs | 27 ++ rig-core/src/extractor.rs | 54 ++- rig-core/src/lib.rs | 2 + .../src/providers/anthropic/completion.rs | 253 ++++++++--- rig-core/src/providers/anthropic/streaming.rs | 65 ++- rig-core/src/providers/azure.rs | 118 
+++-- rig-core/src/providers/cohere/completion.rs | 140 ++++-- rig-core/src/providers/cohere/streaming.rs | 55 ++- rig-core/src/providers/deepseek.rs | 169 +++++-- rig-core/src/providers/galadriel.rs | 105 ++++- rig-core/src/providers/gemini/completion.rs | 280 ++++++++++-- rig-core/src/providers/gemini/streaming.rs | 46 ++ rig-core/src/providers/groq.rs | 150 +++++-- .../src/providers/huggingface/completion.rs | 142 +++++- .../src/providers/huggingface/streaming.rs | 27 +- rig-core/src/providers/hyperbolic.rs | 107 +++-- rig-core/src/providers/mira.rs | 128 ++++-- rig-core/src/providers/mistral/completion.rs | 136 +++++- rig-core/src/providers/moonshot.rs | 150 +++++-- rig-core/src/providers/ollama.rs | 154 +++++-- .../src/providers/openai/completion/mod.rs | 271 ++++++++--- .../providers/openai/completion/streaming.rs | 70 ++- .../src/providers/openai/responses_api/mod.rs | 220 +++++++-- .../openai/responses_api/streaming.rs | 215 +++++---- rig-core/src/providers/openrouter/client.rs | 15 +- .../src/providers/openrouter/completion.rs | 131 ++++-- .../src/providers/openrouter/streaming.rs | 23 +- rig-core/src/providers/perplexity.rs | 113 +++-- rig-core/src/providers/together/completion.rs | 137 ++++-- rig-core/src/providers/together/streaming.rs | 26 +- rig-core/src/providers/xai/completion.rs | 78 +++- rig-core/src/providers/xai/streaming.rs | 25 +- rig-core/src/telemetry/mod.rs | 96 ++++ rig-core/src/tools/think.rs | 2 +- rig-eternalai/CHANGELOG.md | 6 + rig-eternalai/Cargo.toml | 4 +- rig-fastembed/CHANGELOG.md | 6 + rig-fastembed/Cargo.toml | 4 +- rig-helixdb/.gitignore | 1 + rig-helixdb/Cargo.toml | 21 + rig-helixdb/README.md | 35 ++ .../examples/helixdb-cfg/db/queries.hx | 7 + rig-helixdb/examples/helixdb-cfg/db/schema.hx | 4 + rig-helixdb/examples/helixdb-cfg/helix.toml | 15 + rig-helixdb/examples/vector_search_helixdb.rs | 79 ++++ rig-helixdb/src/lib.rs | 181 ++++++++ rig-lancedb/CHANGELOG.md | 6 + rig-lancedb/Cargo.toml | 4 +- rig-milvus/CHANGELOG.md | 6 
+ rig-milvus/Cargo.toml | 4 +- rig-mongodb/CHANGELOG.md | 6 + rig-mongodb/Cargo.toml | 4 +- rig-neo4j/CHANGELOG.md | 6 + rig-neo4j/Cargo.toml | 4 +- rig-postgres/CHANGELOG.md | 6 + rig-postgres/Cargo.toml | 4 +- rig-qdrant/CHANGELOG.md | 6 + rig-qdrant/Cargo.toml | 4 +- rig-s3vectors/CHANGELOG.md | 6 + rig-s3vectors/Cargo.toml | 4 +- rig-scylladb/CHANGELOG.md | 6 + rig-scylladb/Cargo.toml | 4 +- rig-sqlite/CHANGELOG.md | 6 + rig-sqlite/Cargo.toml | 4 +- rig-surrealdb/CHANGELOG.md | 6 + rig-surrealdb/Cargo.toml | 4 +- rust-toolchain.toml | 2 + 96 files changed, 5205 insertions(+), 1238 deletions(-) create mode 100644 rig-core/examples/agent_with_tools_otel.rs create mode 100644 rig-core/examples/dyn_client_streaming.rs create mode 100644 rig-core/examples/openai_agent_completions_api_otel.rs create mode 100644 rig-core/examples/openai_streaming_with_tools_otel.rs create mode 100644 rig-core/examples/otel/Dockerfile create mode 100644 rig-core/examples/otel/config.yaml create mode 100644 rig-core/src/telemetry/mod.rs create mode 100644 rig-helixdb/.gitignore create mode 100644 rig-helixdb/Cargo.toml create mode 100644 rig-helixdb/README.md create mode 100644 rig-helixdb/examples/helixdb-cfg/db/queries.hx create mode 100644 rig-helixdb/examples/helixdb-cfg/db/schema.hx create mode 100644 rig-helixdb/examples/helixdb-cfg/helix.toml create mode 100644 rig-helixdb/examples/vector_search_helixdb.rs create mode 100644 rig-helixdb/src/lib.rs create mode 100644 rust-toolchain.toml diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 34d0c60b5..6dc45ac83 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,9 +31,16 @@ This will then auto-link issue 999 (for example) and will automatically close th ### Code Contribution Guidelines Most non-trivial open source projects often have a set of code contribution guidelines that are highly advised to stick to for the easiest path to a merge. 
Such policies also exist to ensure that the project is able to remain easy to contribute to. -While we will not strictly enforce any guidelines as such because we want to make it as easy as possible to contribute to Rig, we do have two policies that we advise contributors to stick to: +While we will not strictly enforce any guidelines as such because we want to make it as easy as possible to contribute to Rig, we do have three policies that we advise contributors to stick to: - Use docstrings on any new public items (structs, enums, methods whether free-standing or associated). - Ensure that you use full syntax for trait bounds where possible. This makes the code much easier to read. +- If your PR adds additional functionality to Rig, it must include relevant tests that pass (if the code does not directly interact with an API model provider), or alternatively an example that compiles if the code is user-facing. + +As a contributor, you are additionally welcome to use AI assistance for coding. However to make the review process as smooth as possible, it's helpful to keep in mind the following: +- You as a contributor are responsible for ensuring correctness, maintainability and compliance with project standards. Using AI does not change the quality bar. +- Please make it clear in the PR description that it was significantly or 100% generated by AI if it is the case. Adding a short note like "This PR was generated by Claude" to your PR description is generally sufficient. + +AI-generated PRs may require additional review to ensure correctness and long-term maintainability, including possible code style clean-up. Other than that, each PR will be taken on a case-by-case basis. 
diff --git a/Cargo.lock b/Cargo.lock index 037ac1f93..31f374429 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -151,9 +151,9 @@ checksum = "90c6333e01ba7235575b6ab53e5af10f1c327927fd97c36462917e289557ea64" [[package]] name = "anyhow" -version = "1.0.98" +version = "1.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +checksum = "b0674a1ddeecb70197781e945de4b3b8ffb61fa939a5597bcf48503737663100" [[package]] name = "approx" @@ -937,9 +937,9 @@ dependencies = [ [[package]] name = "aws-sdk-bedrockruntime" -version = "1.102.0" +version = "1.104.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e2fce23c31248b7f46d3b346089ab087dd16bbb33b680ddb77a2a36331a1994" +checksum = "1574d1fad8f4bbf71aeb5dbb16653e7db48463f031ae77fdc161621019364d4a" dependencies = [ "aws-credential-types", "aws-runtime", @@ -1515,7 +1515,7 @@ dependencies = [ "bitflags 2.9.1", "cexpr", "clang-sys", - "itertools 0.12.1", + "itertools 0.10.5", "lazy_static", "lazycell", "log", @@ -1686,7 +1686,7 @@ dependencies = [ "serde_json", "serde_repr", "serde_urlencoded", - "thiserror 2.0.12", + "thiserror 2.0.16", "tokio", "tokio-util", "tower-service", @@ -3728,7 +3728,7 @@ checksum = "4005a505741dd1059db4c799a5370b9ea6c5eff3a19cb3475674670d73923d49" dependencies = [ "percent-encoding", "regex", - "thiserror 2.0.12", + "thiserror 2.0.16", "xml-rs", "zip 3.0.0", ] @@ -4892,6 +4892,33 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "helix-macros" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c07cf400a5ee47e259261d951c5615396ab1b518b1129ed562cfbdd60b014dff" +dependencies = [ + "dirs 5.0.1", + "proc-macro2", + "quote", + "syn 2.0.104", +] + +[[package]] +name = "helix-rs" +version = "0.1.9" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d37e0d09decbce0bd0aeec46444449f84e87fae138bbc5180d6870b074f2708" +dependencies = [ + "anyhow", + "helix-macros", + "reqwest 0.12.23", + "serde", + "serde_json", + "thiserror 2.0.16", + "tokio", +] + [[package]] name = "hermit-abi" version = "0.5.2" @@ -4917,10 +4944,10 @@ dependencies = [ "log", "native-tls", "rand 0.9.1", - "reqwest 0.12.20", + "reqwest 0.12.23", "serde", "serde_json", - "thiserror 2.0.12", + "thiserror 2.0.16", "ureq", "windows-sys 0.60.2", ] @@ -5143,7 +5170,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2", + "socket2 0.5.10", "tokio", "tower-service", "tracing", @@ -5267,7 +5294,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2", + "socket2 0.5.10", "system-configuration 0.6.1", "tokio", "tower-service", @@ -5615,13 +5642,24 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "io-uring" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "046fa2d4d00aea763528b4950358d0ead425372445dc8ff86312b3c69ff7727b" +dependencies = [ + "bitflags 2.9.1", + "cfg-if", + "libc", +] + [[package]] name = "ipconfig" version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b58db92f96b720de98181bbbe63c831e87005ab460c1bf306eb2622b4707997f" dependencies = [ - "socket2", + "socket2 0.5.10", "widestring", "windows-sys 0.48.0", "winreg", @@ -6455,7 +6493,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", - "windows-targets 0.53.2", + "windows-targets 0.48.5", ] [[package]] @@ -6597,7 +6635,7 @@ dependencies = [ "rayon", "sha2", "stringprep", - "thiserror 2.0.12", + "thiserror 2.0.16", "time", "weezl", ] @@ -6971,7 +7009,7 @@ dependencies = [ "serde_with", "sha-1", "sha2", - "socket2", + "socket2 0.5.10", "stringprep", "strsim 0.11.1", 
"take_mut", @@ -7402,7 +7440,7 @@ dependencies = [ "percent-encoding", "quick-xml 0.37.5", "rand 0.8.5", - "reqwest 0.12.20", + "reqwest 0.12.23", "ring 0.17.14", "rustls-pemfile 2.2.0", "serde", @@ -7429,7 +7467,7 @@ dependencies = [ "itertools 0.14.0", "parking_lot", "percent-encoding", - "thiserror 2.0.12", + "thiserror 2.0.16", "tokio", "tracing", "url", @@ -7541,6 +7579,80 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "opentelemetry" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aaf416e4cb72756655126f7dd7bb0af49c674f4c1b9903e80c009e0c37e552e6" +dependencies = [ + "futures-core", + "futures-sink", + "js-sys", + "pin-project-lite", + "thiserror 2.0.16", + "tracing", +] + +[[package]] +name = "opentelemetry-http" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50f6639e842a97dbea8886e3439710ae463120091e2e064518ba8e716e6ac36d" +dependencies = [ + "async-trait", + "bytes", + "http 1.3.1", + "opentelemetry", + "reqwest 0.12.23", +] + +[[package]] +name = "opentelemetry-otlp" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbee664a43e07615731afc539ca60c6d9f1a9425e25ca09c57bc36c87c55852b" +dependencies = [ + "http 1.3.1", + "opentelemetry", + "opentelemetry-http", + "opentelemetry-proto", + "opentelemetry_sdk", + "prost", + "reqwest 0.12.23", + "thiserror 2.0.16", + "tracing", +] + +[[package]] +name = "opentelemetry-proto" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e046fd7660710fe5a05e8748e70d9058dc15c94ba914e7c4faa7c728f0e8ddc" +dependencies = [ + "opentelemetry", + "opentelemetry_sdk", + "prost", + "tonic 0.13.1", +] + +[[package]] +name = "opentelemetry_sdk" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11f644aa9e5e31d11896e024305d7e3c98a88884d9f8919dbf37a9991bc47a4b" +dependencies = [ + 
"futures-channel", + "futures-executor", + "futures-util", + "opentelemetry", + "percent-encoding", + "rand 0.9.1", + "serde_json", + "thiserror 2.0.16", + "tokio", + "tokio-stream", +] + [[package]] name = "option-ext" version = "0.2.0" @@ -7815,7 +7927,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1db05f56d34358a8b1066f67cbb203ee3e7ed2ba674a6263a1d5ec6db2204323" dependencies = [ "memchr", - "thiserror 2.0.12", + "thiserror 2.0.16", "ucd-trie", ] @@ -8311,7 +8423,7 @@ dependencies = [ "serde_json", "thiserror 1.0.69", "tokio", - "tonic", + "tonic 0.12.3", ] [[package]] @@ -8385,8 +8497,8 @@ dependencies = [ "quinn-udp", "rustc-hash 2.1.1", "rustls 0.23.28", - "socket2", - "thiserror 2.0.12", + "socket2 0.5.10", + "thiserror 2.0.16", "tokio", "tracing", "web-time", @@ -8407,7 +8519,7 @@ dependencies = [ "rustls 0.23.28", "rustls-pki-types", "slab", - "thiserror 2.0.12", + "thiserror 2.0.16", "tinyvec", "tracing", "web-time", @@ -8422,7 +8534,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2", + "socket2 0.5.10", "tracing", "windows-sys 0.59.0", ] @@ -8711,7 +8823,7 @@ checksum = "dd6f9d3d47bdd2ad6945c5015a226ec6155d0bcdfd8f7cd29f86b71f8de99d2b" dependencies = [ "getrandom 0.2.16", "libredox", - "thiserror 2.0.12", + "thiserror 2.0.16", ] [[package]] @@ -8836,13 +8948,14 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.20" +version = "0.12.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eabf4c97d9130e2bf606614eb937e86edac8292eaa6f422f995d7e8de1eb1813" +checksum = "d429f34c8092b2d42c7c93cec323bb4adeb7c67698f70839adec842ec10c7ceb" dependencies = [ "base64 0.22.1", "bytes", "encoding_rs", + "futures-channel", "futures-core", "futures-util", "h2 0.4.10", @@ -8895,7 +9008,7 @@ dependencies = [ "mime", "nom 7.1.3", "pin-project-lite", - "reqwest 0.12.20", + "reqwest 0.12.23", "thiserror 1.0.69", ] @@ -8975,7 +9088,7 @@ checksum = 
"57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a" [[package]] name = "rig-bedrock" -version = "0.3.1" +version = "0.3.2" dependencies = [ "anyhow", "async-stream", @@ -8983,8 +9096,8 @@ dependencies = [ "aws-sdk-bedrockruntime", "aws-smithy-types", "base64 0.22.1", - "reqwest 0.12.20", - "rig-core 0.20.0", + "reqwest 0.12.23", + "rig-core 0.21.0", "rig-derive", "schemars 1.0.4", "serde", @@ -9009,11 +9122,11 @@ dependencies = [ "glob", "mime_guess", "ordered-float", - "reqwest 0.12.20", + "reqwest 0.12.23", "schemars 0.8.22", "serde", "serde_json", - "thiserror 2.0.12", + "thiserror 2.0.16", "tracing", "url", "worker", @@ -9021,7 +9134,7 @@ dependencies = [ [[package]] name = "rig-core" -version = "0.20.0" +version = "0.21.0" dependencies = [ "anyhow", "as-any", @@ -9037,10 +9150,13 @@ dependencies = [ "hyper-util", "lopdf", "mime_guess", + "opentelemetry", + "opentelemetry-otlp", + "opentelemetry_sdk", "ordered-float", "quick-xml 0.38.0", "rayon", - "reqwest 0.12.20", + "reqwest 0.12.23", "reqwest-eventsource", "rig-derive", "rmcp", @@ -9048,10 +9164,12 @@ dependencies = [ "serde", "serde_json", "serde_path_to_error", - "thiserror 2.0.12", + "thiserror 2.0.16", "tokio", "tokio-test", "tracing", + "tracing-futures", + "tracing-opentelemetry", "tracing-subscriber", "url", "worker", @@ -9066,7 +9184,7 @@ dependencies = [ "indoc", "proc-macro2", "quote", - "rig-core 0.20.0", + "rig-core 0.21.0", "serde", "serde_json", "syn 2.0.104", @@ -9076,13 +9194,13 @@ dependencies = [ [[package]] name = "rig-eternalai" -version = "0.3.7" +version = "0.3.8" dependencies = [ "anyhow", "async-stream", "ethers", - "reqwest 0.12.20", - "rig-core 0.20.0", + "reqwest 0.12.23", + "rig-core 0.21.0", "schemars 1.0.4", "serde", "serde_json", @@ -9093,11 +9211,11 @@ dependencies = [ [[package]] name = "rig-fastembed" -version = "0.2.11" +version = "0.2.12" dependencies = [ "anyhow", "fastembed", - "rig-core 0.20.0", + "rig-core 0.21.0", "schemars 1.0.4", "serde", 
"serde_json", @@ -9105,9 +9223,20 @@ dependencies = [ "tracing", ] +[[package]] +name = "rig-helixdb" +version = "0.1.0" +dependencies = [ + "helix-rs", + "rig-core 0.21.0", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "rig-lancedb" -version = "0.2.22" +version = "0.2.23" dependencies = [ "anyhow", "arrow-array", @@ -9115,7 +9244,7 @@ dependencies = [ "futures", "httpmock", "lancedb", - "rig-core 0.20.0", + "rig-core 0.21.0", "serde", "serde_json", "tokio", @@ -9123,11 +9252,11 @@ dependencies = [ [[package]] name = "rig-milvus" -version = "0.1.11" +version = "0.1.12" dependencies = [ "anyhow", - "reqwest 0.12.20", - "rig-core 0.20.0", + "reqwest 0.12.23", + "rig-core 0.21.0", "serde", "serde_json", "tokio", @@ -9137,13 +9266,13 @@ dependencies = [ [[package]] name = "rig-mongodb" -version = "0.2.22" +version = "0.2.23" dependencies = [ "anyhow", "futures", "httpmock", "mongodb", - "rig-core 0.20.0", + "rig-core 0.21.0", "serde", "serde_json", "testcontainers", @@ -9154,13 +9283,13 @@ dependencies = [ [[package]] name = "rig-neo4j" -version = "0.3.6" +version = "0.3.7" dependencies = [ "anyhow", "futures", "httpmock", "neo4rs", - "rig-core 0.20.0", + "rig-core 0.21.0", "serde", "serde_json", "term_size", @@ -9173,14 +9302,14 @@ dependencies = [ [[package]] name = "rig-postgres" -version = "0.1.20" +version = "0.1.21" dependencies = [ "anyhow", "dotenvy", "httpmock", "log", "pgvector", - "rig-core 0.20.0", + "rig-core 0.21.0", "serde", "serde_json", "sqlx", @@ -9194,12 +9323,12 @@ dependencies = [ [[package]] name = "rig-qdrant" -version = "0.1.25" +version = "0.1.26" dependencies = [ "anyhow", "httpmock", "qdrant-client", - "rig-core 0.20.0", + "rig-core 0.21.0", "serde", "serde_json", "testcontainers", @@ -9209,13 +9338,13 @@ dependencies = [ [[package]] name = "rig-s3vectors" -version = "0.1.8" +version = "0.1.9" dependencies = [ "anyhow", "aws-config", "aws-sdk-s3vectors", "aws-smithy-types", - "rig-core 0.20.0", + "rig-core 0.21.0", "serde", 
"serde_json", "tokio", @@ -9226,13 +9355,13 @@ dependencies = [ [[package]] name = "rig-scylladb" -version = "0.1.11" +version = "0.1.12" dependencies = [ "anyhow", "chrono", "futures", "httpmock", - "rig-core 0.20.0", + "rig-core 0.21.0", "scylla", "serde", "serde_json", @@ -9245,12 +9374,12 @@ dependencies = [ [[package]] name = "rig-sqlite" -version = "0.1.22" +version = "0.1.23" dependencies = [ "anyhow", "chrono", "httpmock", - "rig-core 0.20.0", + "rig-core 0.21.0", "rusqlite", "serde", "serde_json", @@ -9264,10 +9393,10 @@ dependencies = [ [[package]] name = "rig-surrealdb" -version = "0.1.17" +version = "0.1.18" dependencies = [ "anyhow", - "rig-core 0.20.0", + "rig-core 0.21.0", "serde", "serde_json", "surrealdb", @@ -9404,13 +9533,13 @@ dependencies = [ "paste", "pin-project-lite", "rand 0.9.1", - "reqwest 0.12.20", + "reqwest 0.12.23", "rmcp-macros", "schemars 1.0.4", "serde", "serde_json", "sse-stream", - "thiserror 2.0.12", + "thiserror 2.0.16", "tokio", "tokio-stream", "tokio-util", @@ -9934,8 +10063,8 @@ dependencies = [ "rand_pcg", "scylla-cql", "smallvec", - "socket2", - "thiserror 2.0.12", + "socket2 0.5.10", + "thiserror 2.0.16", "tokio", "tracing", "uuid 1.17.0", @@ -9955,7 +10084,7 @@ dependencies = [ "scylla-macros", "snap", "stable_deref_trait", - "thiserror 2.0.12", + "thiserror 2.0.16", "tokio", "uuid 1.17.0", "yoke 0.7.5", @@ -10118,9 +10247,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.143" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a" dependencies = [ "indexmap 2.10.0", "itoa", @@ -10333,7 +10462,7 @@ checksum = "297f631f50729c8c99b84667867963997ec0b50f32b2a7dbcab828ef0541e8bb" dependencies = [ "num-bigint", "num-traits", - "thiserror 2.0.12", + "thiserror 2.0.16", "time", ] @@ -10429,6 +10558,16 @@ dependencies = [ 
"windows-sys 0.52.0", ] +[[package]] +name = "socket2" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "socks" version = "0.3.4" @@ -10574,7 +10713,7 @@ dependencies = [ "serde_json", "sha2", "smallvec", - "thiserror 2.0.12", + "thiserror 2.0.16", "tokio", "tokio-stream", "tracing", @@ -10657,7 +10796,7 @@ dependencies = [ "smallvec", "sqlx-core", "stringprep", - "thiserror 2.0.12", + "thiserror 2.0.16", "tracing", "uuid 1.17.0", "whoami", @@ -10695,7 +10834,7 @@ dependencies = [ "smallvec", "sqlx-core", "stringprep", - "thiserror 2.0.12", + "thiserror 2.0.16", "tracing", "uuid 1.17.0", "whoami", @@ -10720,7 +10859,7 @@ dependencies = [ "serde", "serde_urlencoded", "sqlx-core", - "thiserror 2.0.12", + "thiserror 2.0.16", "tracing", "url", "uuid 1.17.0", @@ -10911,7 +11050,7 @@ dependencies = [ "path-clean", "pharos", "reblessive", - "reqwest 0.12.20", + "reqwest 0.12.23", "revision 0.11.0", "ring 0.17.14", "rust_decimal", @@ -11453,7 +11592,7 @@ dependencies = [ "serde", "serde_json", "serde_with", - "thiserror 2.0.12", + "thiserror 2.0.16", "tokio", "tokio-stream", "tokio-tar", @@ -11483,11 +11622,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.12" +version = "2.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +checksum = "3467d614147380f2e4e374161426ff399c91084acd2363eaf549172b3d5e60c0" dependencies = [ - "thiserror-impl 2.0.12", + "thiserror-impl 2.0.16", ] [[package]] @@ -11503,9 +11642,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.12" +version = "2.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +checksum = 
"6c5e1be1c48b9172ee610da68fd9cd2770e7a4056cb3fc98710ee6906f0c7960" dependencies = [ "proc-macro2", "quote", @@ -11624,7 +11763,7 @@ dependencies = [ "serde", "serde_json", "spm_precompiled", - "thiserror 2.0.12", + "thiserror 2.0.16", "unicode-normalization-alignments", "unicode-segmentation", "unicode_categories", @@ -11632,20 +11771,22 @@ dependencies = [ [[package]] name = "tokio" -version = "1.45.1" +version = "1.47.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" +checksum = "89e49afdadebb872d3145a5638b59eb0691ea23e46ca484037cfab3b76b95038" dependencies = [ "backtrace", "bytes", + "io-uring", "libc", "mio", "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "slab", + "socket2 0.6.0", "tokio-macros", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -11860,7 +12001,7 @@ dependencies = [ "prost", "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", - "socket2", + "socket2 0.5.10", "tokio", "tokio-rustls 0.26.2", "tokio-stream", @@ -11870,6 +12011,27 @@ dependencies = [ "tracing", ] +[[package]] +name = "tonic" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e581ba15a835f4d9ea06c55ab1bd4dce26fc53752c69a04aac00703bfb49ba9" +dependencies = [ + "async-trait", + "base64 0.22.1", + "bytes", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "percent-encoding", + "pin-project", + "prost", + "tokio-stream", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower" version = "0.4.13" @@ -11975,6 +12137,8 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2" dependencies = [ + "futures", + "futures-task", "pin-project", "tracing", ] @@ -11990,6 +12154,24 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-opentelemetry" +version = 
"0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddcf5959f39507d0d04d6413119c04f33b623f4f951ebcbdddddfad2d0623a9c" +dependencies = [ + "js-sys", + "once_cell", + "opentelemetry", + "opentelemetry_sdk", + "smallvec", + "tracing", + "tracing-core", + "tracing-log", + "tracing-subscriber", + "web-time", +] + [[package]] name = "tracing-subscriber" version = "0.3.19" @@ -12696,7 +12878,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.48.0", ] [[package]] @@ -13179,7 +13361,7 @@ dependencies = [ "serde", "serde-wasm-bindgen", "serde_json", - "thiserror 2.0.12", + "thiserror 2.0.16", "wasm-bindgen", "wasm-bindgen-futures", ] @@ -13231,7 +13413,7 @@ dependencies = [ "pharos", "rustc_version", "send_wrapper 0.6.0", - "thiserror 2.0.12", + "thiserror 2.0.16", "wasm-bindgen", "wasm-bindgen-futures", "web-sys", diff --git a/Cargo.toml b/Cargo.toml index 9e43b2e13..7176353ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ members = [ "rig-bedrock", "rig-milvus", "rig-wasm", - "rig-s3vectors", + "rig-s3vectors", "rig-helixdb", ] [profile.release] diff --git a/README.md b/README.md index fbe7f5f62..568222912 100644 --- a/README.md +++ b/README.md @@ -56,10 +56,15 @@ Rig is a Rust library for building scalable, modular, and ergonomic **LLM-powere More information about this crate can be found in the [official](https://docs.rig.rs) & [crate](https://docs.rs/rig-core/latest/rig/) (API Reference) documentations. 
-## High-level features +## Features +- Agentic workflows that can handle multi-turn streaming and prompting +- Full [GenAI Semantic Convention](https://opentelemetry.io/docs/specs/semconv/gen-ai/) compatibility +- 20+ model providers, all under one singular unified interface +- 10+ vector store integrations, all under one singular unified interface - Full support for LLM completion and embedding workflows -- Simple but powerful common abstractions over LLM providers (e.g. OpenAI, Cohere) and vector stores (e.g. MongoDB, SQlite, in-memory) +- Support for transcription, audio generation and image generation model capabilities - Integrate LLMs in your app with minimal boilerplate +- Full WASM compatibility (core library only) ## Who is using Rig in production? Below is a non-exhaustive list of companies and people who are using Rig in production: @@ -70,6 +75,7 @@ Below is a non-exhaustive list of companies and people who are using Rig in prod - [rig-onchain-kit](https://github.com/0xPlaygrounds/rig-onchain-kit) - the Rig Onchain Kit. Intended to make interactions between Solana/EVM and Rig much easier to implement. - [Linera Protocol](https://github.com/linera-io/linera-protocol) - Decentralized blockchain infrastructure designed for highly scalable, secure, low-latency Web3 applications. - [Listen](https://github.com/piotrostr/listen) - A framework aiming to become the go-to framework for AI portfolio management agents. Powers [the Listen app.](https://app.listen-rs.com/) +- [VT Code](https://github.com/vinhnx/vtcode) - VT Code is a Rust-based terminal coding agent with semantic code intelligence via Tree-sitter and ast-grep. VT Code uses `rig` for simplifying LLM calls and implement model picker. Are you also using Rig in production? [Open an issue](https://www.github.com/0xPlaygrounds/rig/issues) to have your name added! 
@@ -116,6 +122,7 @@ Vector stores are available as separate companion-crates: - Milvus: [`rig-milvus`](https://github.com/0xPlaygrounds/rig/tree/main/rig-milvus) - ScyllaDB: [`rig-scylladb`](https://github.com/0xPlaygrounds/rig/tree/main/rig-scylladb) - AWS S3Vectors: [`rig-s3vectors`](https://github.com/0xPlaygrounds/rig/tree/main/rig-s3vectors) +- HelixDB: [`rig-helixdb`](https://github.com/0xPlaygrounds/rig/tree/main/rig-helixdb) The following providers are available as separate companion-crates: - Fastembed: [`rig-fastembed`](https://github.com/0xPlaygrounds/rig/tree/main/rig-fastembed) diff --git a/rig-bedrock/CHANGELOG.md b/rig-bedrock/CHANGELOG.md index 766bf61f6..306407520 100644 --- a/rig-bedrock/CHANGELOG.md +++ b/rig-bedrock/CHANGELOG.md @@ -7,6 +7,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.3.2](https://github.com/0xPlaygrounds/rig/compare/rig-bedrock-v0.3.1...rig-bedrock-v0.3.2) - 2025-09-29 + +### Added + +- *(rig-795)* support file URLs for audio, video, documents ([#823](https://github.com/0xPlaygrounds/rig/pull/823)) + +### Other + +- *(rig-963)* fix feature regression in AWS bedrock ([#863](https://github.com/0xPlaygrounds/rig/pull/863)) + ## [0.3.1](https://github.com/0xPlaygrounds/rig/compare/rig-bedrock-v0.3.0...rig-bedrock-v0.3.1) - 2025-09-15 ### Added diff --git a/rig-bedrock/Cargo.toml b/rig-bedrock/Cargo.toml index f9d83e5c7..57f903ebc 100644 --- a/rig-bedrock/Cargo.toml +++ b/rig-bedrock/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rig-bedrock" -version = "0.3.1" +version = "0.3.2" edition = { workspace = true } license = "MIT" readme = "README.md" @@ -12,7 +12,7 @@ aws-config = { workspace = true, features = ["behavior-version-latest"] } aws-sdk-bedrockruntime = { workspace = true } aws-smithy-types = { workspace = true } base64 = { workspace = true } -rig-core = { path = "../rig-core", version = "0.20.0", features = ["image"] } +rig-core = { path = "../rig-core", version 
= "0.21.0", features = ["image"] } rig-derive = { path = "../rig-core/rig-core-derive", version = "0.1.6" } schemars = { workspace = true } serde = { workspace = true, features = ["derive"] } diff --git a/rig-bedrock/src/types/document.rs b/rig-bedrock/src/types/document.rs index 1b6a6bad6..8b85384c0 100644 --- a/rig-bedrock/src/types/document.rs +++ b/rig-bedrock/src/types/document.rs @@ -27,18 +27,26 @@ impl TryFrom for aws_bedrock::DocumentBlock { None => Ok(None), }?; - let DocumentSourceKind::Base64(data) = data else { - return Err(CompletionError::RequestError( - "Invalid document format".into(), - )); - }; - - let data = BASE64_STANDARD - .decode(data) - .map_err(|e| CompletionError::ProviderError(e.to_string()))?; + let document_source = match data { + DocumentSourceKind::Base64(blob) => { + let bytes = BASE64_STANDARD + .decode(blob) + .map_err(|e| CompletionError::RequestError(e.into()))?; - let data = aws_smithy_types::Blob::new(data); - let document_source = aws_bedrock::DocumentSource::Bytes(data); + aws_bedrock::DocumentSource::Bytes(aws_smithy_types::Blob::new(bytes)) + } + // NOTE: until [aws-sdk-bedrockruntime DocumentSource bug #1365](https://github.com/awslabs/aws-sdk-rust/issues/1365) + // is resolved we will use this as a workaround + // DocumentSourceKind::String(str) => aws_bedrock::DocumentSource::Text(str), + DocumentSourceKind::String(str) => { + aws_bedrock::DocumentSource::Bytes(aws_smithy_types::Blob::new(str.as_bytes())) + } + doc => { + return Err(CompletionError::RequestError( + format!("Unsupported document kind: {doc}").into(), + )); + } + }; let random_string = Uuid::new_v4().simple().to_string(); let document_name = format!("document-{random_string}"); @@ -64,9 +72,10 @@ impl TryFrom for RigDocument { let encoded_data = BASE64_STANDARD.encode(blob.into_inner()); Ok(DocumentSourceKind::Base64(encoded_data)) } - _ => Err(CompletionError::ProviderError( - "Document source is missing".into(), - )), + 
Some(aws_bedrock::DocumentSource::Text(str)) => Ok(DocumentSourceKind::String(str)), + doc => Err(CompletionError::ProviderError(format!( + "Unsupported document type: {doc:?}" + ))), }?; Ok(RigDocument(Document { diff --git a/rig-core/CHANGELOG.md b/rig-core/CHANGELOG.md index 35d8dc7dc..0b52beb5f 100644 --- a/rig-core/CHANGELOG.md +++ b/rig-core/CHANGELOG.md @@ -7,6 +7,36 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.21.0](https://github.com/0xPlaygrounds/rig/compare/rig-core-v0.20.0...rig-core-v0.21.0) - 2025-09-29 + +### Added + +- GenAI SemConv support (otel) ([#850](https://github.com/0xPlaygrounds/rig/pull/850)) +- add streaming support to DynClientBuilder ([#824](https://github.com/0xPlaygrounds/rig/pull/824)) +- *(rig-912)* rework `Chat` trait for multi-turn ([#846](https://github.com/0xPlaygrounds/rig/pull/846)) +- *(rig-795)* support file URLs for audio, video, documents ([#823](https://github.com/0xPlaygrounds/rig/pull/823)) +- *(rig-943)* support thinking for cohere ([#827](https://github.com/0xPlaygrounds/rig/pull/827)) + +### Fixed + +- only youtube videos should accept null mime type (gemini) ([#873](https://github.com/0xPlaygrounds/rig/pull/873)) +- *(rig-970)* file URLs should be able to accept empty media type (Gemini) ([#872](https://github.com/0xPlaygrounds/rig/pull/872)) +- *(rig-970)* youtube video ingestion doesn't work (gemini) +- fix(rig-962)(deepseek): tool calls not recognised when put behind text content ([#862](https://github.com/0xPlaygrounds/rig/pull/862)) +- fix-853 ([#854](https://github.com/0xPlaygrounds/rig/pull/854)) +- *(rig-956)* DocumentSourceKind fails to serialize with common serializers ([#849](https://github.com/0xPlaygrounds/rig/pull/849)) +- *(rig-957)* huggingface should convert image URLs ([#848](https://github.com/0xPlaygrounds/rig/pull/848)) +- *(rig-950)* openai imagegen doesn't work with gpt-image-1 ([#837](https://github.com/0xPlaygrounds/rig/pull/837)) 
+- ci lints ([#832](https://github.com/0xPlaygrounds/rig/pull/832)) + +### Other + +- *(rig-969)* update features on README ([#870](https://github.com/0xPlaygrounds/rig/pull/870)) +- *(rig-963)* fix feature regression in AWS bedrock ([#863](https://github.com/0xPlaygrounds/rig/pull/863)) +- fix typo in comment ([#866](https://github.com/0xPlaygrounds/rig/pull/866)) +- parse NDJSON correctly, fixes #825 ([#826](https://github.com/0xPlaygrounds/rig/pull/826)) +- make Reasoning non-exhaustive ([#830](https://github.com/0xPlaygrounds/rig/pull/830)) + ## [0.20.0](https://github.com/0xPlaygrounds/rig/compare/rig-core-v0.19.0...rig-core-v0.20.0) - 2025-09-15 ### Added diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 97ef558e0..fc498376b 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rig-core" -version = "0.20.0" +version = "0.21.0" edition = { workspace = true } license = "MIT" readme = "README.md" @@ -44,6 +44,7 @@ rmcp = { version = "0.6", optional = true, features = ["client"] } reqwest-eventsource = { workspace = true } tokio = { workspace = true, features = ["sync"] } http = "1.3.1" +tracing-futures = { version = "0.2.5", features = ["futures-03"] } [dev-dependencies] anyhow = { workspace = true } @@ -68,6 +69,12 @@ rmcp = { version = "0.6", features = [ ] } axum = "0.8.4" +# required for otel +opentelemetry = "0.30.0" +opentelemetry_sdk = { version = "0.30.0", features = ["rt-tokio"] } +opentelemetry-otlp = "0.30.0" +tracing-opentelemetry = "0.31.0" + [features] default = ["reqwest/default"] diff --git a/rig-core/README.md b/rig-core/README.md index 7ce7114f6..1fa3ec16a 100644 --- a/rig-core/README.md +++ b/rig-core/README.md @@ -10,11 +10,17 @@ More information about this crate can be found in the [crate documentation](http - [Installation](#installation) - [Simple example:](#simple-example) - [Integrations](#integrations) + - [Who is using Rig in production?](#who-is-using-rig-in-production) -## High-level 
features +## Features +- Agentic workflows that can handle multi-turn streaming and prompting +- Full [GenAI Semantic Convention](https://opentelemetry.io/docs/specs/semconv/gen-ai/) compatibility +- 20+ model providers, all under one singular unified interface +- 10+ vector store integrations, all under one singular unified interface - Full support for LLM completion and embedding workflows -- Simple but powerful common abstractions over LLM providers (e.g. OpenAI, Cohere) and vector stores (e.g. MongoDB, SQLite, in-memory) +- Support for transcription, audio generation and image generation model capabilities - Integrate LLMs in your app with minimal boilerplate +- Full WASM compatibility (core library only) ## Installation ```bash @@ -84,3 +90,15 @@ The following providers are available as separate companion-crates: - Fastembed: [`rig-fastembed`](https://github.com/0xPlaygrounds/rig/tree/main/rig-fastembed) - Eternal AI: [`rig-eternalai`](https://github.com/0xPlaygrounds/rig/tree/main/rig-eternalai) + +## Who is using Rig in production? +Below is a non-exhaustive list of companies and people who are using Rig in production: +- [Dria Compute Node](https://github.com/firstbatchxyz/dkn-compute-node) - a node that serves computation results within the Dria Knowledge Network +- [The MCP Rust SDK](https://github.com/modelcontextprotocol/rust-sdk ) - the official Model Context Protocol Rust SDK. Has an example for usage with Rig. +- [Probe](https://github.com/buger/probe) - an AI-friendly, fully local semantic code search tool. +- [NINE](https://github.com/NethermindEth/nine) - Neural Interconnected Nodes Engine, by [Nethermind.](https://www.nethermind.io/) +- [rig-onchain-kit](https://github.com/0xPlaygrounds/rig-onchain-kit) - the Rig Onchain Kit. Intended to make interactions between Solana/EVM and Rig much easier to implement. 
+- [Linera Protocol](https://github.com/linera-io/linera-protocol) - Decentralized blockchain infrastructure designed for highly scalable, secure, low-latency Web3 applications. +- [Listen](https://github.com/piotrostr/listen) - A framework aiming to become the go-to framework for AI portfolio management agents. Powers [the Listen app.](https://app.listen-rs.com/) + +Are you also using Rig in production? [Open an issue](https://www.github.com/0xPlaygrounds/rig/issues) to have your name added! diff --git a/rig-core/examples/agent_with_echochambers.rs b/rig-core/examples/agent_with_echochambers.rs index 1f7cb2956..19e29e96a 100644 --- a/rig-core/examples/agent_with_echochambers.rs +++ b/rig-core/examples/agent_with_echochambers.rs @@ -2,7 +2,7 @@ use anyhow::Result; use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderValue}; use rig::prelude::*; use rig::{ - cli_chatbot::ChatbotBuilder, + cli_chatbot::ChatBotBuilder, completion::ToolDefinition, providers::openai::{Client, GPT_4O}, tool::Tool, @@ -361,7 +361,7 @@ async fn main() -> Result<(), anyhow::Error> { .build(); // Build a CLI chatbot from the agent, with multi-turn enabled - let chatbot = ChatbotBuilder::new() + let chatbot = ChatBotBuilder::new() .agent(echochambers_agent) .multi_turn_depth(10) .build(); diff --git a/rig-core/examples/agent_with_tools_otel.rs b/rig-core/examples/agent_with_tools_otel.rs new file mode 100644 index 000000000..3394d643b --- /dev/null +++ b/rig-core/examples/agent_with_tools_otel.rs @@ -0,0 +1,166 @@ +//! Agent multi-turn with tools, but with a tracing subscriber that sends all logs/traces to an OTel collector. +//! Note that if the tool runs too fast, a given observability platform may put traces in the wrong order +//! hence the delay. +//! +//! In production, this is very unlikely to be a problem as many of the tools used may include MCP servers and other long-running +//! operations, which may cause issues. 
+use std::time::Duration; + +use anyhow::Result; +use opentelemetry::trace::TracerProvider; +use opentelemetry_otlp::WithExportConfig; +use opentelemetry_sdk::Resource; +use opentelemetry_sdk::trace::SdkTracerProvider; +use rig::prelude::*; +use rig::{ + completion::{Prompt, ToolDefinition}, + providers, + tool::Tool, +}; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use tracing::Level; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +#[derive(Deserialize)] +struct OperationArgs { + x: i32, + y: i32, +} + +#[derive(Debug, thiserror::Error)] +#[error("Math error")] +struct MathError; + +#[derive(Deserialize, Serialize)] +struct Adder; +impl Tool for Adder { + const NAME: &'static str = "add"; + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + ToolDefinition { + name: "add".to_string(), + description: "Add x and y together".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The first number to add" + }, + "y": { + "type": "number", + "description": "The second number to add" + } + }, + "required": ["x", "y"], + }), + } + } + + async fn call(&self, args: Self::Args) -> Result { + println!("[tool-call] Adding {} and {}", args.x, args.y); + let result = args.x + args.y; + Ok(result) + } +} + +#[derive(Deserialize, Serialize)] +struct Subtract; + +impl Tool for Subtract { + const NAME: &'static str = "subtract"; + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + serde_json::from_value(json!({ + "name": "subtract", + "description": "Subtract y from x (i.e.: x - y)", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The number to subtract from" + }, + "y": { + "type": "number", + "description": "The 
number to subtract" + } + }, + "required": ["x", "y"], + }, + })) + .expect("Tool Definition") + } + + async fn call(&self, args: Self::Args) -> Result { + println!("[tool-call] Subtracting {} from {}", args.y, args.x); + let result = args.x - args.y; + // Sleep for 1 microsecond to allow simulating a more compute-heavy tool + // Tools with <1ms execution time can get mixed up in tracing order on + // observability backend platforms + tokio::time::sleep(Duration::from_micros(1)).await; + Ok(result) + } +} + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + let exporter = opentelemetry_otlp::SpanExporter::builder() + .with_http() + .with_protocol(opentelemetry_otlp::Protocol::HttpBinary) + .build()?; + // Create a new OpenTelemetry trace pipeline that prints to stdout + let provider = SdkTracerProvider::builder() + .with_batch_exporter(exporter) + .with_resource(Resource::builder().with_service_name("rig-demo").build()) + .build(); + let tracer = provider.tracer("readme_example"); + + // Create a tracing layer with the configured tracer + let otel_layer = tracing_opentelemetry::layer().with_tracer(tracer); + let filter_layer = tracing_subscriber::filter::EnvFilter::builder() + .with_default_directive(Level::INFO.into()) + .from_env_lossy(); + + let fmt_layer = tracing_subscriber::fmt::layer().pretty(); + + // Use the tracing subscriber `Registry`, or any other subscriber + // that impls `LookupSpan` + tracing_subscriber::registry() + .with(filter_layer) + .with(fmt_layer) + .with(otel_layer) + .init(); + + // Create OpenAI client + let openai_client = providers::openai::Client::from_env(); + + // Create agent with a single context prompt and two tools + let calculator_agent = openai_client + .agent(providers::openai::GPT_4O) + .preamble("You are a calculator here to help the user perform arithmetic operations. 
Use the tools provided to answer the user's question.") + .max_tokens(1024) + .tool(Adder) + .tool(Subtract) + .build(); + + // Prompt the agent and print the response + println!("Calculate 2 - 5"); + + println!( + "OpenAI Calculator Agent: {}", + calculator_agent.prompt("Calculate 2 - 5").await? + ); + + let _ = provider.shutdown(); + + Ok(()) +} diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 9f5bad81a..68cabfe9f 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -1,7 +1,7 @@ use anyhow::Result; +use rig::cli_chatbot::ChatBotBuilder; use rig::prelude::*; use rig::{ - cli_chatbot::ChatbotBuilder, completion::ToolDefinition, embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, @@ -274,7 +274,7 @@ async fn main() -> Result<(), anyhow::Error> { .build(); // Create a CLI chatbot from the agent - let chatbot = ChatbotBuilder::new().agent(calculator_rag).build(); + let chatbot = ChatBotBuilder::new().agent(calculator_rag).build(); chatbot.run().await?; diff --git a/rig-core/examples/dyn_client_streaming.rs b/rig-core/examples/dyn_client_streaming.rs new file mode 100644 index 000000000..038b3e0d5 --- /dev/null +++ b/rig-core/examples/dyn_client_streaming.rs @@ -0,0 +1,156 @@ +use futures::StreamExt; +/// This example showcases using streaming with multiple clients by using a dynamic ClientBuilder. +/// In this example, we will use both OpenAI and Anthropic with streaming responses - so ensure you have your `OPENAI_API_KEY` and `ANTHROPIC_API_KEY` set when using this example! 
+use rig::{client::builder::DynClientBuilder, providers::anthropic::CLAUDE_3_7_SONNET}; + +#[tokio::main] +async fn main() { + let multi_client = DynClientBuilder::new(); + + // Test streaming with OpenAI + println!("=== Testing OpenAI Streaming ==="); + match test_openai_streaming(&multi_client).await { + Ok(_) => println!("OpenAI streaming test completed successfully"), + Err(e) => println!("OpenAI streaming test failed: {}", e), + } + + // Test streaming with Anthropic + println!("\n=== Testing Anthropic Streaming ==="); + match test_anthropic_streaming(&multi_client).await { + Ok(_) => println!("Anthropic streaming test completed successfully"), + Err(e) => println!("Anthropic streaming test failed: {}", e), + } + + // Test streaming with ProviderModelId + println!("\n=== Testing ProviderModelId Streaming ==="); + match test_provider_model_id_streaming(&multi_client).await { + Ok(_) => println!("ProviderModelId streaming test completed successfully"), + Err(e) => println!("ProviderModelId streaming test failed: {}", e), + } +} + +async fn test_openai_streaming( + client: &DynClientBuilder, +) -> Result<(), Box> { + println!( + "Streaming prompt to OpenAI (gpt-4o): 'Tell me a short story about a robot learning to paint'" + ); + + let mut stream = client + .stream_prompt( + "openai", + "gpt-4o", + "Tell me a short story about a robot learning to paint", + ) + .await?; + + print!("Response: "); + while let Some(chunk) = stream.next().await { + match chunk { + Ok(rig::streaming::StreamedAssistantContent::Text(text)) => { + print!("{}", text.text); + std::io::Write::flush(&mut std::io::stdout())?; + } + Ok(rig::streaming::StreamedAssistantContent::Reasoning(reasoning)) => { + println!("\n[Reasoning: {}]", reasoning.reasoning.join("")); + } + Ok(rig::streaming::StreamedAssistantContent::ToolCall(tool_call)) => { + println!("\n[Tool Call: {}]", tool_call.function.name); + } + Ok(rig::streaming::StreamedAssistantContent::Final(_)) => { + println!("\n[Stream 
completed]"); + break; + } + Err(e) => { + println!("\n[Error: {}]", e); + break; + } + } + } + println!(); + + Ok(()) +} + +async fn test_anthropic_streaming( + client: &DynClientBuilder, +) -> Result<(), Box> { + println!( + "Streaming prompt to Anthropic (Claude 3.7 Sonnet): 'Explain quantum computing in simple terms'" + ); + + let mut stream = client + .stream_prompt( + "anthropic", + CLAUDE_3_7_SONNET, + "Explain quantum computing in simple terms", + ) + .await?; + + print!("Response: "); + while let Some(chunk) = stream.next().await { + match chunk { + Ok(rig::streaming::StreamedAssistantContent::Text(text)) => { + print!("{}", text.text); + std::io::Write::flush(&mut std::io::stdout())?; + } + Ok(rig::streaming::StreamedAssistantContent::Reasoning(reasoning)) => { + println!("\n[Reasoning: {}]", reasoning.reasoning.join("")); + } + Ok(rig::streaming::StreamedAssistantContent::ToolCall(tool_call)) => { + println!("\n[Tool Call: {}]", tool_call.function.name); + } + Ok(rig::streaming::StreamedAssistantContent::Final(_)) => { + println!("\n[Stream completed]"); + break; + } + Err(e) => { + println!("\n[Error: {}]", e); + break; + } + } + } + println!(); + + Ok(()) +} + +async fn test_provider_model_id_streaming( + client: &DynClientBuilder, +) -> Result<(), Box> { + println!( + "Streaming prompt using ProviderModelId: 'What are the benefits of renewable energy?'" + ); + + let provider_model = client.id("openai:gpt-4o")?; + let mut stream = provider_model + .stream_prompt("What are the benefits of renewable energy?") + .await?; + + print!("Response: "); + while let Some(chunk) = stream.next().await { + match chunk { + Ok(rig::streaming::StreamedAssistantContent::Text(text)) => { + print!("{}", text.text); + std::io::Write::flush(&mut std::io::stdout())?; + } + Ok(rig::streaming::StreamedAssistantContent::Reasoning(reasoning)) => { + println!("\n[Reasoning: {}]", reasoning.reasoning.join("")); + } + 
Ok(rig::streaming::StreamedAssistantContent::ToolCall(tool_call)) => { + println!("\n[Tool Call: {}]", tool_call.function.name); + } + Ok(rig::streaming::StreamedAssistantContent::Final(_)) => { + println!("\n[Stream completed]"); + break; + } + Err(e) => { + println!("\n[Error: {}]", e); + break; + } + } + } + println!(); + + Ok(()) +} diff --git a/rig-core/examples/multi_agent.rs b/rig-core/examples/multi_agent.rs index f4dc0c65c..33e35f8de 100644 --- a/rig-core/examples/multi_agent.rs +++ b/rig-core/examples/multi_agent.rs @@ -1,8 +1,8 @@ use anyhow::Result; +use rig::cli_chatbot::ChatBotBuilder; use rig::prelude::*; use rig::{ agent::{Agent, AgentBuilder}, - cli_chatbot::ChatbotBuilder, completion::{Chat, CompletionModel, PromptError, ToolDefinition}, providers::openai::Client as OpenAIClient, tool::Tool, @@ -88,7 +88,7 @@ async fn main() -> Result<(), anyhow::Error> { .build(); // Spin up a CLI chatbot using the multi-agent system - let chatbot = ChatbotBuilder::new() + let chatbot = ChatBotBuilder::new() .agent(multi_agent_system) .multi_turn_depth(1) .build(); diff --git a/rig-core/examples/openai_agent_completions_api_otel.rs b/rig-core/examples/openai_agent_completions_api_otel.rs new file mode 100644 index 000000000..e397bdde2 --- /dev/null +++ b/rig-core/examples/openai_agent_completions_api_otel.rs @@ -0,0 +1,60 @@ +//! This example shows how you can use OpenAI's Completions API. +//! By default, the OpenAI integration uses the Responses API. However, for the sake of backwards compatibility you may wish to use the Completions API. 
+ +use opentelemetry_otlp::WithExportConfig; +use opentelemetry_sdk::Resource; +use rig::completion::Prompt; +use rig::prelude::*; + +use opentelemetry::trace::TracerProvider as _; +use opentelemetry_sdk::trace::SdkTracerProvider; +use rig::providers; +use tracing::Level; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + let exporter = opentelemetry_otlp::SpanExporter::builder() + .with_http() + .with_protocol(opentelemetry_otlp::Protocol::HttpBinary) + .build()?; + // Create a new OpenTelemetry trace pipeline that prints to stdout + let provider = SdkTracerProvider::builder() + .with_batch_exporter(exporter) + .with_resource(Resource::builder().with_service_name("rig-demo").build()) + .build(); + let tracer = provider.tracer("readme_example"); + + // Create a tracing layer with the configured tracer + let otel_layer = tracing_opentelemetry::layer().with_tracer(tracer); + let filter_layer = tracing_subscriber::filter::EnvFilter::builder() + .with_default_directive(Level::INFO.into()) + .from_env_lossy(); + + let fmt_layer = tracing_subscriber::fmt::layer().compact(); + + // Use the tracing subscriber `Registry`, or any other subscriber + // that impls `LookupSpan` + tracing_subscriber::registry() + .with(filter_layer) + .with(fmt_layer) + .with(otel_layer) + .init(); + + // Create OpenAI client + let agent = providers::openai::Client::from_env() + .completion_model("gpt-4o") + .completions_api() + .into_agent_builder() + .preamble("You are a helpful assistant") + .build(); + + let res = agent.prompt("Hello world!").await.unwrap(); + + println!("GPT-4o: {res}"); + + let _ = provider.shutdown(); + + Ok(()) +} diff --git a/rig-core/examples/openai_streaming_with_tools_otel.rs b/rig-core/examples/openai_streaming_with_tools_otel.rs new file mode 100644 index 000000000..0e21f5973 --- /dev/null +++ b/rig-core/examples/openai_streaming_with_tools_otel.rs @@ 
-0,0 +1,154 @@ +use anyhow::Result; +use rig::agent::stream_to_stdout; +use rig::prelude::*; + +use rig::{completion::ToolDefinition, providers, streaming::StreamingPrompt, tool::Tool}; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +use opentelemetry::trace::TracerProvider; +use opentelemetry_otlp::WithExportConfig; +use opentelemetry_sdk::Resource; +use opentelemetry_sdk::trace::SdkTracerProvider; +use tracing::Level; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; + +#[derive(Deserialize)] +struct OperationArgs { + x: i32, + y: i32, +} + +#[derive(Debug, thiserror::Error)] +#[error("Math error")] +struct MathError; + +#[derive(Deserialize, Serialize)] +struct Adder; + +impl Tool for Adder { + const NAME: &'static str = "add"; + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + ToolDefinition { + name: "add".to_string(), + description: "Add x and y together".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The first number to add" + }, + "y": { + "type": "number", + "description": "The second number to add" + } + }, + "required": ["x", "y"], + }), + } + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x + args.y; + Ok(result) + } +} + +#[derive(Deserialize, Serialize)] +struct Subtract; + +impl Tool for Subtract { + const NAME: &'static str = "subtract"; + type Error = MathError; + type Args = OperationArgs; + type Output = i32; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + serde_json::from_value(json!({ + "name": "subtract", + "description": "Subtract y from x (i.e.: x - y)", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "The number to subtract from" + }, + "y": { + "type": "number", + "description": "The number to subtract" + } + 
}, + "required": ["x", "y"], + } + })) + .expect("Tool Definition") + } + + async fn call(&self, args: Self::Args) -> Result { + let result = args.x - args.y; + Ok(result) + } +} + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + let exporter = opentelemetry_otlp::SpanExporter::builder() + .with_http() + .with_protocol(opentelemetry_otlp::Protocol::HttpBinary) + .build()?; + // Create a new OpenTelemetry trace pipeline that prints to stdout + let provider = SdkTracerProvider::builder() + .with_batch_exporter(exporter) + .with_resource(Resource::builder().with_service_name("rig-demo").build()) + .build(); + let tracer = provider.tracer("readme_example"); + + // Create a tracing layer with the configured tracer + let otel_layer = tracing_opentelemetry::layer().with_tracer(tracer); + let filter_layer = tracing_subscriber::filter::EnvFilter::builder() + .with_default_directive(Level::INFO.into()) + .from_env_lossy(); + + let fmt_layer = tracing_subscriber::fmt::layer().pretty(); + + // Use the tracing subscriber `Registry`, or any other subscriber + // that impls `LookupSpan` + tracing_subscriber::registry() + .with(filter_layer) + .with(fmt_layer) + .with(otel_layer) + .init(); + + // Create agent with a single context prompt and two tools + let calculator_agent = providers::openai::Client::from_env() + .agent(providers::openai::GPT_4O) + .preamble( + "You are a calculator here to help the user perform arithmetic + operations. Use the tools provided to answer the user's question. 
+ make your answer long, so we can test the streaming functionality, + like 20 words", + ) + .max_tokens(1024) + .tool(Adder) + .tool(Subtract) + .name("Bob") + .build(); + + let mut stream = calculator_agent.stream_prompt("Calculate 2 - 5").await; + + let res = stream_to_stdout(&mut stream).await?; + + println!("Token usage response: {usage:?}", usage = res.usage()); + println!("Final text response: {message:?}", message = res.response()); + + let _ = provider.shutdown(); + + Ok(()) +} diff --git a/rig-core/examples/otel/Dockerfile b/rig-core/examples/otel/Dockerfile new file mode 100644 index 000000000..a4e48a394 --- /dev/null +++ b/rig-core/examples/otel/Dockerfile @@ -0,0 +1,6 @@ +# Start from the official OpenTelemetry Collector Contrib image +FROM otel/opentelemetry-collector-contrib:0.135.0 + +# Copy your local config into the container +# Replace `config.yaml` with your actual filename if different +COPY ./config.yaml /etc/otelcol-contrib/config.yaml diff --git a/rig-core/examples/otel/config.yaml b/rig-core/examples/otel/config.yaml new file mode 100644 index 000000000..f28eb9882 --- /dev/null +++ b/rig-core/examples/otel/config.yaml @@ -0,0 +1,28 @@ +receivers: + otlp: + protocols: + http: + endpoint: 0.0.0.0:4318 + +processors: + transform: + trace_statements: + - context: span + statements: + # Rename span if it's "invoke_agent" and has an agent attribute + - set(name, attributes["gen_ai.agent.name"]) where name == "invoke_agent" and attributes["gen_ai.agent.name"] != nil + +exporters: + debug: + verbosity: detailed + otlphttp/langfuse: + endpoint: "https://cloud.langfuse.com/api/public/otel" + headers: + Authorization: "Basic ${AUTH_STRING}" + +service: + pipelines: + traces: + receivers: [otlp] + processors: [transform] + exporters: [otlphttp/langfuse, debug] diff --git a/rig-core/examples/pdf_agent.rs b/rig-core/examples/pdf_agent.rs index ab309929f..394564091 100644 --- a/rig-core/examples/pdf_agent.rs +++ b/rig-core/examples/pdf_agent.rs @@ -1,5 
+1,5 @@ use anyhow::{Context, Result}; -use rig::cli_chatbot::ChatbotBuilder; +use rig::cli_chatbot::ChatBotBuilder; use rig::prelude::*; use rig::{ Embed, embeddings::EmbeddingsBuilder, loaders::PdfFileLoader, providers::openai, @@ -96,7 +96,7 @@ async fn main() -> Result<()> { println!("Starting CLI chatbot..."); // Start interactive CLI - let chatbot = ChatbotBuilder::new() + let chatbot = ChatBotBuilder::new() .agent(rag_agent) .multi_turn_depth(10) .build(); diff --git a/rig-core/src/agent/builder.rs b/rig-core/src/agent/builder.rs index 7ecd0e389..07cac167d 100644 --- a/rig-core/src/agent/builder.rs +++ b/rig-core/src/agent/builder.rs @@ -2,6 +2,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ completion::{CompletionModel, Document}, + message::ToolChoice, tool::{Tool, ToolSet}, vector_store::VectorStoreIndexDyn, }; @@ -39,6 +40,8 @@ where { /// Name of the agent used for logging and debugging name: Option, + /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and converting agents to other formats. + description: Option, /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r) model: M, /// System prompt @@ -59,6 +62,8 @@ where temperature: Option, /// Actual tool implementations tools: ToolSet, + /// Whether or not the underlying LLM should be forced to use a tool before providing a response. 
+ tool_choice: Option, } impl AgentBuilder @@ -68,6 +73,7 @@ where pub fn new(model: M) -> Self { Self { name: None, + description: None, model, preamble: None, static_context: vec![], @@ -78,6 +84,7 @@ where dynamic_context: vec![], dynamic_tools: vec![], tools: ToolSet::default(), + tool_choice: None, } } @@ -87,6 +94,12 @@ where self } + /// Set the description of the agent + pub fn description(mut self, description: &str) -> Self { + self.description = Some(description.into()); + self + } + /// Set the system prompt pub fn preamble(mut self, preamble: &str) -> Self { self.preamble = Some(preamble.into()); @@ -149,6 +162,11 @@ where self } + pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self { + self.tool_choice = Some(tool_choice); + self + } + /// Add some dynamic tools to the agent. On each prompt, `sample` tools from the /// dynamic toolset will be inserted in the request. pub fn dynamic_tools( @@ -184,6 +202,7 @@ where pub fn build(self) -> Agent { Agent { name: self.name, + description: self.description, model: Arc::new(self.model), preamble: self.preamble, static_context: self.static_context, @@ -191,6 +210,7 @@ where temperature: self.temperature, max_tokens: self.max_tokens, additional_params: self.additional_params, + tool_choice: self.tool_choice, dynamic_context: Arc::new(self.dynamic_context), dynamic_tools: Arc::new(self.dynamic_tools), tools: Arc::new(self.tools), diff --git a/rig-core/src/agent/completion.rs b/rig-core/src/agent/completion.rs index 4c66e1b61..adc85259f 100644 --- a/rig-core/src/agent/completion.rs +++ b/rig-core/src/agent/completion.rs @@ -5,6 +5,7 @@ use crate::{ Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document, GetTokenUsage, Message, Prompt, PromptError, }, + message::ToolChoice, streaming::{StreamingChat, StreamingCompletion, StreamingPrompt}, tool::ToolSet, vector_store::{VectorStoreError, request::VectorSearchRequest}, @@ -42,6 +43,8 @@ where { /// Name of the agent used for 
logging and debugging pub name: Option, + /// Agent description. Primarily useful when using sub-agents as part of an agent workflow and converting agents to other formats. + pub description: Option, /// Completion model (e.g.: OpenAI's gpt-3.5-turbo-1106, Cohere's command-r) pub model: Arc, /// System prompt @@ -62,6 +65,8 @@ where pub dynamic_tools: Arc)>>, /// Actual tool implementations pub tools: Arc, + /// Whether or not the underlying LLM should be forced to use a tool before providing a response. + pub tool_choice: Option, } impl Agent @@ -72,19 +77,12 @@ where pub(crate) fn name(&self) -> &str { self.name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME) } - - /// Returns the name of the agent as an owned variable. - /// Useful in some cases where having the agent name as an owned variable is required. - pub(crate) fn name_owned(&self) -> String { - self.name.clone().unwrap_or(UNKNOWN_AGENT_NAME.to_string()) - } } impl Completion for Agent where M: CompletionModel, { - #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))] async fn completion( &self, prompt: impl Into + Send, @@ -228,7 +226,6 @@ impl Prompt for Agent where M: CompletionModel, { - #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))] fn prompt( &self, prompt: impl Into + Send, @@ -272,7 +269,6 @@ impl StreamingCompletion for Agent where M: CompletionModel, { - #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))] async fn stream_completion( &self, prompt: impl Into + Send, @@ -289,7 +285,6 @@ where M: CompletionModel + 'static, M::StreamingResponse: GetTokenUsage, { - #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))] fn stream_prompt(&self, prompt: impl Into + Send) -> StreamingPromptRequest { let arc = Arc::new(self.clone()); StreamingPromptRequest::new(arc, prompt) diff --git a/rig-core/src/agent/prompt_request/mod.rs b/rig-core/src/agent/prompt_request/mod.rs index 
7aee18fb6..7a400073d 100644 --- a/rig-core/src/agent/prompt_request/mod.rs +++ b/rig-core/src/agent/prompt_request/mod.rs @@ -1,8 +1,14 @@ pub(crate) mod streaming; -use std::{future::IntoFuture, marker::PhantomData}; +use std::{ + future::IntoFuture, + marker::PhantomData, + sync::atomic::{AtomicU64, Ordering}, +}; +use tracing::{Instrument, span::Id}; use futures::{FutureExt, StreamExt, future::BoxFuture, stream}; +use tracing::info_span; use crate::{ OneOrMany, @@ -146,7 +152,6 @@ where #[allow(unused_variables)] /// Called after the prompt is sent to the model and a response is received. - /// This function is for non-streamed responses. Please refer to `on_stream_completion_response_finish` for streamed responses. fn on_completion_response( &self, prompt: &Message, @@ -155,16 +160,6 @@ where async {} } - #[allow(unused_variables)] - /// Called after the model provider has finished streaming a text response from their completion API to the client. - fn on_stream_completion_response_finish( - &self, - prompt: &Message, - response: &::StreamingResponse, - ) -> impl Future + Send { - async {} - } - #[allow(unused_variables)] /// Called before a tool is invoked. 
fn on_tool_call(&self, tool_name: &str, args: &str) -> impl Future + Send { @@ -244,18 +239,37 @@ where M: CompletionModel, P: PromptHook, { - #[tracing::instrument(skip(self), fields(agent_name = self.agent.name()))] async fn send(self) -> Result { + let agent_span = if tracing::Span::current().is_disabled() { + info_span!( + "invoke_agent", + gen_ai.operation.name = "invoke_agent", + gen_ai.agent.name = self.agent.name(), + gen_ai.system_instructions = self.agent.preamble, + gen_ai.prompt = tracing::field::Empty, + gen_ai.completion = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + let agent = self.agent; let chat_history = if let Some(history) = self.chat_history { - history.push(self.prompt); + history.push(self.prompt.to_owned()); history } else { - &mut vec![self.prompt] + &mut vec![self.prompt.to_owned()] }; + if let Some(text) = self.prompt.rag_text() { + agent_span.record("gen_ai.prompt", text); + } + let mut current_max_depth = 0; let mut usage = Usage::new(); + let current_span_id: AtomicU64 = AtomicU64::new(0); // We need to do at least 2 loops for 1 roundtrip (user expects normal message) let last_prompt = loop { @@ -282,6 +296,33 @@ where hook.on_completion_call(&prompt, &chat_history[..chat_history.len() - 1]) .await; } + let span = tracing::Span::current(); + let chat_span = info_span!( + target: "rig::agent_chat", + parent: &span, + "chat", + gen_ai.operation.name = "chat", + gen_ai.system_instructions = self.agent.preamble, + gen_ai.provider.name = tracing::field::Empty, + gen_ai.request.model = tracing::field::Empty, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = tracing::field::Empty, + gen_ai.output.messages = tracing::field::Empty, + ); 
+ + let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 { + let id = Id::from_u64(current_span_id.load(Ordering::SeqCst)); + chat_span.follows_from(id).to_owned() + } else { + chat_span + }; + + if let Some(id) = chat_span.id() { + current_span_id.store(id.into_u64(), Ordering::SeqCst); + }; let resp = agent .completion( @@ -290,6 +331,7 @@ where ) .await? .send() + .instrument(chat_span.clone()) .await?; usage += resp.usage; @@ -325,6 +367,10 @@ where tracing::info!("Depth reached: {}/{}", current_max_depth, self.max_depth); } + agent_span.record("gen_ai.completion", &merged_texts); + agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens); + agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens); + // If there are no tool calls, depth is not relevant, we can just return the merged text response. return Ok(PromptResponse::new(merged_texts, usage)); } @@ -334,10 +380,36 @@ where .then(|choice| { let hook1 = hook.clone(); let hook2 = hook.clone(); + + let tool_span = info_span!( + "execute_tool", + gen_ai.operation.name = "execute_tool", + gen_ai.tool.type = "function", + gen_ai.tool.name = tracing::field::Empty, + gen_ai.tool.call.id = tracing::field::Empty, + gen_ai.tool.call.arguments = tracing::field::Empty, + gen_ai.tool.call.result = tracing::field::Empty + ); + + let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 { + let id = Id::from_u64(current_span_id.load(Ordering::SeqCst)); + tool_span.follows_from(id).to_owned() + } else { + tool_span + }; + + if let Some(id) = tool_span.id() { + current_span_id.store(id.into_u64(), Ordering::SeqCst); + }; + async move { if let AssistantContent::ToolCall(tool_call) = choice { let tool_name = &tool_call.function.name; let args = tool_call.function.arguments.to_string(); + let tool_span = tracing::Span::current(); + tool_span.record("gen_ai.tool.name", tool_name); + tool_span.record("gen_ai.tool.call.id", &tool_call.id); + tool_span.record("gen_ai.tool.call.arguments", 
&args); if let Some(hook) = hook1 { hook.on_tool_call(tool_name, &args).await; } @@ -346,6 +418,10 @@ where hook.on_tool_result(tool_name, &args, &output.to_string()) .await; } + tool_span.record("gen_ai.tool.call.result", &output); + tracing::info!( + "executed tool {tool_name} with args {args}. result: {output}" + ); if let Some(call_id) = tool_call.call_id.clone() { Ok(UserContent::tool_result_with_call_id( tool_call.id.clone(), @@ -364,6 +440,7 @@ where ) } } + .instrument(tool_span) }) .collect::>>() .await diff --git a/rig-core/src/agent/prompt_request/streaming.rs b/rig-core/src/agent/prompt_request/streaming.rs index 12796e170..eb331ade3 100644 --- a/rig-core/src/agent/prompt_request/streaming.rs +++ b/rig-core/src/agent/prompt_request/streaming.rs @@ -1,6 +1,5 @@ use crate::{ OneOrMany, - agent::prompt_request::PromptHook, completion::GetTokenUsage, message::{AssistantContent, Reasoning, ToolResultContent, UserContent}, streaming::{StreamedAssistantContent, StreamingCompletion}, @@ -9,6 +8,8 @@ use futures::{Stream, StreamExt}; use serde::{Deserialize, Serialize}; use std::{pin::Pin, sync::Arc}; use tokio::sync::RwLock; +use tracing::info_span; +use tracing_futures::Instrument; use crate::{ agent::Agent, @@ -91,7 +92,7 @@ pub enum StreamingError { pub struct StreamingPromptRequest where M: CompletionModel, - P: PromptHook + 'static, + P: StreamingPromptHook + 'static, { /// The prompt message to send to the model prompt: Message, @@ -110,7 +111,7 @@ impl StreamingPromptRequest where M: CompletionModel + 'static, ::StreamingResponse: Send + GetTokenUsage, - P: PromptHook, + P: StreamingPromptHook, { /// Create a new PromptRequest with the given prompt and model pub fn new(agent: Arc>, prompt: impl Into) -> Self { @@ -139,7 +140,7 @@ where /// Attach a per-request hook for tool call events pub fn with_hook(self, hook: P2) -> StreamingPromptRequest where - P2: PromptHook, + P2: StreamingPromptHook, { StreamingPromptRequest { prompt: self.prompt, @@ -152,101 
+153,152 @@ where #[cfg_attr(feature = "worker", worker::send)] async fn send(self) -> StreamingResult { - let agent_name = self.agent.name_owned(); - - #[tracing::instrument(skip_all, fields(agent_name = agent_name))] - fn inner( - req: StreamingPromptRequest, - agent_name: String, - ) -> StreamingResult - where - M: CompletionModel + 'static, - ::StreamingResponse: Send, - P: PromptHook + 'static, - { - let prompt = req.prompt; - let agent = req.agent; - - let chat_history = if let Some(mut history) = req.chat_history { - history.push(prompt.clone()); - Arc::new(RwLock::new(history)) - } else { - Arc::new(RwLock::new(vec![prompt.clone()])) - }; - - let mut current_max_depth = 0; - let mut last_prompt_error = String::new(); - - let mut last_text_response = String::new(); - let mut is_text_response = false; - let mut max_depth_reached = false; - - let mut aggregated_usage = crate::completion::Usage::new(); - - Box::pin(async_stream::stream! { - let mut current_prompt = prompt.clone(); - let mut did_call_tool = false; - - 'outer: loop { - if current_max_depth > req.max_depth + 1 { - last_prompt_error = current_prompt.rag_text().unwrap_or_default(); - max_depth_reached = true; - break; - } + let agent_span = if tracing::Span::current().is_disabled() { + info_span!( + "invoke_agent", + gen_ai.operation.name = "invoke_agent", + gen_ai.agent.name = self.agent.name(), + gen_ai.system_instructions = self.agent.preamble, + gen_ai.prompt = tracing::field::Empty, + gen_ai.completion = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + + let prompt = self.prompt; + if let Some(text) = prompt.rag_text() { + agent_span.record("gen_ai.prompt", text); + } - current_max_depth += 1; + let agent = self.agent; - if req.max_depth > 1 { - tracing::info!( - "Current conversation depth: {}/{}", - current_max_depth, - req.max_depth - ); - } + let chat_history = 
if let Some(history) = self.chat_history { + Arc::new(RwLock::new(history)) + } else { + Arc::new(RwLock::new(vec![])) + }; - if let Some(ref hook) = req.hook { - let reader = chat_history.read().await; - let prompt = reader.last().cloned().expect("there should always be at least one message in the chat history"); - let chat_history_except_last = reader[..reader.len() - 1].to_vec(); + let mut current_max_depth = 0; + let mut last_prompt_error = String::new(); - hook.on_completion_call(&prompt, &chat_history_except_last) - .await; - } + let mut last_text_response = String::new(); + let mut is_text_response = false; + let mut max_depth_reached = false; + let mut aggregated_usage = crate::completion::Usage::new(); - let mut stream = agent - .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone()) - .await? - .stream() - .await?; + Box::pin(async_stream::stream! { + let _guard = agent_span.enter(); + let mut current_prompt = prompt.clone(); + let mut did_call_tool = false; - chat_history.write().await.push(current_prompt.clone()); + 'outer: loop { + if current_max_depth > self.max_depth + 1 { + last_prompt_error = current_prompt.rag_text().unwrap_or_default(); + max_depth_reached = true; + break; + } - let mut tool_calls = vec![]; - let mut tool_results = vec![]; + current_max_depth += 1; - while let Some(content) = stream.next().await { - match content { - Ok(StreamedAssistantContent::Text(text)) => { - if !is_text_response { - last_text_response = String::new(); - is_text_response = true; - } - last_text_response.push_str(&text.text); - yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text))); - did_call_tool = false; - }, - Ok(StreamedAssistantContent::ToolCall(tool_call)) => { - if let Some(ref hook) = req.hook { + if self.max_depth > 1 { + tracing::info!( + "Current conversation depth: {}/{}", + current_max_depth, + self.max_depth + ); + } + + if let Some(ref hook) = self.hook { + let reader = 
chat_history.read().await; + let prompt = reader.last().cloned().expect("there should always be at least one message in the chat history"); + let chat_history_except_last = reader[..reader.len() - 1].to_vec(); + + hook.on_completion_call(&prompt, &chat_history_except_last) + .await; + } + + let chat_stream_span = info_span!( + target: "rig::agent_chat", + parent: tracing::Span::current(), + "chat_streaming", + gen_ai.operation.name = "chat", + gen_ai.system_instructions = &agent.preamble, + gen_ai.provider.name = tracing::field::Empty, + gen_ai.request.model = tracing::field::Empty, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = tracing::field::Empty, + gen_ai.output.messages = tracing::field::Empty, + ); + + let mut stream = tracing::Instrument::instrument( + agent + .stream_completion(current_prompt.clone(), (*chat_history.read().await).clone()) + .await? 
+ .stream(), chat_stream_span + ) + + .await?; + + chat_history.write().await.push(current_prompt.clone()); + + let mut tool_calls = vec![]; + let mut tool_results = vec![]; + + while let Some(content) = stream.next().await { + match content { + Ok(StreamedAssistantContent::Text(text)) => { + if !is_text_response { + last_text_response = String::new(); + is_text_response = true; + } + last_text_response.push_str(&text.text); + if let Some(ref hook) = self.hook { + hook.on_text_delta(&text.text, &last_text_response).await; + } + yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Text(text))); + did_call_tool = false; + }, + Ok(StreamedAssistantContent::ToolCall(tool_call)) => { + let tool_span = info_span!( + parent: tracing::Span::current(), + "execute_tool", + gen_ai.operation.name = "execute_tool", + gen_ai.tool.type = "function", + gen_ai.tool.name = tracing::field::Empty, + gen_ai.tool.call.id = tracing::field::Empty, + gen_ai.tool.call.arguments = tracing::field::Empty, + gen_ai.tool.call.result = tracing::field::Empty + ); + + async { + let tool_span = tracing::Span::current(); + if let Some(ref hook) = self.hook { hook.on_tool_call(&tool_call.function.name, &tool_call.function.arguments.to_string()).await; } - let tool_result = - agent.tools.call(&tool_call.function.name, tool_call.function.arguments.to_string()).await?; - if let Some(ref hook) = req.hook { + tool_span.record("gen_ai.tool.name", &tool_call.function.name); + tool_span.record("gen_ai.tool.call.arguments", tool_call.function.arguments.to_string()); + + let tool_result = match + agent.tools.call(&tool_call.function.name, tool_call.function.arguments.to_string()).await { + Ok(thing) => thing, + Err(e) => e.to_string() + }; + + tool_span.record("gen_ai.tool.call.result", &tool_result); + + if let Some(ref hook) = self.hook { hook.on_tool_result(&tool_call.function.name, &tool_call.function.arguments.to_string(), &tool_result.to_string()) - .await; + .await; } + let tool_call_msg = 
AssistantContent::ToolCall(tool_call.clone()); tool_calls.push(tool_call_msg); @@ -254,87 +306,89 @@ where did_call_tool = true; // break; - }, - Ok(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id })) => { - chat_history.write().await.push(rig::message::Message::Assistant { - id: None, - content: OneOrMany::one(AssistantContent::Reasoning(Reasoning { - reasoning: reasoning.clone(), id: id.clone() - })) - }); - yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id }))); - did_call_tool = false; - }, - Ok(StreamedAssistantContent::Final(final_resp)) => { - if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; }; - if is_text_response { - if let Some(ref hook) = req.hook { - hook.on_stream_completion_response_finish(&prompt, &final_resp).await; - } - yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp))); - is_text_response = false; + }.instrument(tool_span).await + }, + Ok(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id })) => { + chat_history.write().await.push(rig::message::Message::Assistant { + id: None, + content: OneOrMany::one(AssistantContent::Reasoning(Reasoning { + reasoning: reasoning.clone(), id: id.clone() + })) + }); + yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Reasoning(rig::message::Reasoning { reasoning, id }))); + did_call_tool = false; + }, + Ok(StreamedAssistantContent::Final(final_resp)) => { + if let Some(usage) = final_resp.token_usage() { aggregated_usage += usage; }; + if is_text_response { + if let Some(ref hook) = self.hook { + hook.on_stream_completion_response_finish(&prompt, &final_resp).await; } - } - Err(e) => { - yield Err(e.into()); - break 'outer; + tracing::Span::current().record("gen_ai.completion", &last_text_response); + yield Ok(MultiTurnStreamItem::stream_item(StreamedAssistantContent::Final(final_resp))); + is_text_response = 
false; } } - } - - // Add (parallel) tool calls to chat history - if !tool_calls.is_empty() { - chat_history.write().await.push(Message::Assistant { - id: None, - content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"), - }); - } - - // Add tool results to chat history - for (id, call_id, tool_result) in tool_results { - if let Some(call_id) = call_id { - chat_history.write().await.push(Message::User { - content: OneOrMany::one(UserContent::tool_result_with_call_id( - &id, - call_id.clone(), - OneOrMany::one(ToolResultContent::text(&tool_result)), - )), - }); - } else { - chat_history.write().await.push(Message::User { - content: OneOrMany::one(UserContent::tool_result( - &id, - OneOrMany::one(ToolResultContent::text(&tool_result)), - )), - }); + Err(e) => { + yield Err(e.into()); + break 'outer; } - } + } - // Set the current prompt to the last message in the chat history - current_prompt = match chat_history.write().await.pop() { - Some(prompt) => prompt, - None => unreachable!("Chat history should never be empty at this point"), - }; + // Add (parallel) tool calls to chat history + if !tool_calls.is_empty() { + chat_history.write().await.push(Message::Assistant { + id: None, + content: OneOrMany::many(tool_calls.clone()).expect("Impossible EmptyListError"), + }); + } - if !did_call_tool { - yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage)); - break; + // Add tool results to chat history + for (id, call_id, tool_result) in tool_results { + if let Some(call_id) = call_id { + chat_history.write().await.push(Message::User { + content: OneOrMany::one(UserContent::tool_result_with_call_id( + &id, + call_id.clone(), + OneOrMany::one(ToolResultContent::text(&tool_result)), + )), + }); + } else { + chat_history.write().await.push(Message::User { + content: OneOrMany::one(UserContent::tool_result( + &id, + OneOrMany::one(ToolResultContent::text(&tool_result)), + )), + }); } } - if max_depth_reached { - yield 
Err(Box::new(PromptError::MaxDepthError { - max_depth: req.max_depth, - chat_history: Box::new((*chat_history.read().await).clone()), - prompt: last_prompt_error.into(), - }).into()); - } + // Set the current prompt to the last message in the chat history + current_prompt = match chat_history.write().await.pop() { + Some(prompt) => prompt, + None => unreachable!("Chat history should never be empty at this point"), + }; + + if !did_call_tool { + let current_span = tracing::Span::current(); + current_span.record("gen_ai.usage.input_tokens", aggregated_usage.input_tokens); + current_span.record("gen_ai.usage.output_tokens", aggregated_usage.output_tokens); + tracing::info!("Agent multi-turn stream finished"); + yield Ok(MultiTurnStreamItem::final_response(&last_text_response, aggregated_usage)); + break; + } + } - }) - } + if max_depth_reached { + yield Err(Box::new(PromptError::MaxDepthError { + max_depth: self.max_depth, + chat_history: Box::new((*chat_history.read().await).clone()), + prompt: last_prompt_error.clone().into(), + }).into()); + } - inner(self, agent_name) + }) } } @@ -342,7 +396,7 @@ impl IntoFuture for StreamingPromptRequest where M: CompletionModel + 'static, ::StreamingResponse: Send, - P: PromptHook + 'static, + P: StreamingPromptHook + 'static, { type Output = StreamingResult; // what `.await` returns type IntoFuture = Pin + Send>>; @@ -353,7 +407,7 @@ where } } -/// helper function to stream a completion request to stdout +/// helper function to stream a completion selfuest to stdout pub async fn stream_to_stdout( stream: &mut StreamingResult, ) -> Result { @@ -384,3 +438,59 @@ pub async fn stream_to_stdout( Ok(final_res) } + +// dead code allowed because of functions being left empty to allow for users to not have to implement every single function +/// Trait for per-request hooks to observe tool call events. 
+pub trait StreamingPromptHook: Clone + Send + Sync +where + M: CompletionModel, +{ + #[allow(unused_variables)] + /// Called before the prompt is sent to the model + fn on_completion_call( + &self, + prompt: &Message, + history: &[Message], + ) -> impl Future + Send { + async {} + } + + #[allow(unused_variables)] + /// Called when receiving a text delta + fn on_text_delta( + &self, + text_delta: &str, + aggregated_text: &str, + ) -> impl Future + Send { + async {} + } + + #[allow(unused_variables)] + /// Called after the model provider has finished streaming a text response from their completion API to the client. + fn on_stream_completion_response_finish( + &self, + prompt: &Message, + response: &::StreamingResponse, + ) -> impl Future + Send { + async {} + } + + #[allow(unused_variables)] + /// Called before a tool is invoked. + fn on_tool_call(&self, tool_name: &str, args: &str) -> impl Future + Send { + async {} + } + + #[allow(unused_variables)] + /// Called after a tool is invoked (and a result has been returned). + fn on_tool_result( + &self, + tool_name: &str, + args: &str, + result: &str, + ) -> impl Future + Send { + async {} + } +} + +impl StreamingPromptHook for () where M: CompletionModel {} diff --git a/rig-core/src/agent/tool.rs b/rig-core/src/agent/tool.rs index 329e8b82f..e7dad6f87 100644 --- a/rig-core/src/agent/tool.rs +++ b/rig-core/src/agent/tool.rs @@ -20,15 +20,21 @@ impl Tool for Agent { type Output = String; async fn definition(&self, _prompt: String) -> ToolDefinition { + let description = format!( + " + Prompt a sub-agent to do a task for you. + + Agent name: {name} + Agent description: {description} + Agent system prompt: {sysprompt} + ", + name = self.name(), + description = self.description.clone().unwrap_or_default(), + sysprompt = self.preamble.clone().unwrap_or_default() + ); ToolDefinition { name: ::name(self), - description: format!( - "A tool that allows the agent to call another agent by prompting it. 
The preamble - of that agent follows: - --- - {}", - self.preamble.clone().unwrap_or_default() - ), + description, parameters: serde_json::to_value(schema_for!(AgentToolArgs)) .expect("converting JSON schema to JSON value should never fail"), } diff --git a/rig-core/src/cli_chatbot.rs b/rig-core/src/cli_chatbot.rs index 66627d877..4ca023e0a 100644 --- a/rig-core/src/cli_chatbot.rs +++ b/rig-core/src/cli_chatbot.rs @@ -1,238 +1,223 @@ -use std::io::{self, Write}; - -use futures::StreamExt; - use crate::{ - agent::{Agent, Text, prompt_request::streaming::MultiTurnStreamItem}, - completion::{Chat, CompletionError, CompletionModel, Message, PromptError}, + agent::{Agent, MultiTurnStreamItem, Text}, + completion::{Chat, CompletionError, CompletionModel, PromptError, Usage}, + message::Message, streaming::{StreamedAssistantContent, StreamingPrompt}, }; +use futures::StreamExt; +use std::io::{self, Write}; -/// Type-state representing an empty `agent` field in `ChatbotBuilder` -pub struct AgentNotSet; - -/// Builder pattern for CLI chatbots. 
-/// -/// # Example -/// ```rust -/// let chatbot = ChatbotBuilder::new().agent(my_agent).show_usage().build(); -/// -/// chatbot.run().await?; -pub struct ChatbotBuilder { - agent: A, +pub struct NoImplProvided; + +pub struct ChatImpl(T) +where + T: Chat; + +pub struct AgentImpl +where + M: CompletionModel + 'static, +{ + agent: Agent, multi_turn_depth: usize, show_usage: bool, + usage: Usage, } -impl Default for ChatbotBuilder { - fn default() -> Self { - ChatbotBuilder { - agent: AgentNotSet, - multi_turn_depth: 0, - show_usage: false, - } +pub struct ChatBotBuilder(T); + +pub struct ChatBot(T); + +/// Trait to abstract message behavior away from cli_chat/`run` loop +#[allow(private_interfaces)] +trait CliChat { + async fn request(&mut self, prompt: &str, history: Vec) + -> Result; + + fn show_usage(&self) -> bool { + false + } + + fn usage(&self) -> Option { + None } } -impl ChatbotBuilder { - pub fn new() -> Self { - Default::default() +impl CliChat for ChatImpl +where + T: Chat, +{ + async fn request( + &mut self, + prompt: &str, + history: Vec, + ) -> Result { + let res = self.0.chat(prompt, history).await?; + println!("{res}"); + + Ok(res) } +} - /// Sets the agent that will be used to drive the CLI interface - pub fn agent(self, agent: Agent) -> ChatbotBuilder> - where - M: CompletionModel + 'static, - { - ChatbotBuilder { - agent, - multi_turn_depth: self.multi_turn_depth, - show_usage: self.show_usage, +impl CliChat for AgentImpl +where + M: CompletionModel + 'static, +{ + async fn request( + &mut self, + prompt: &str, + history: Vec, + ) -> Result { + let mut response_stream = self + .agent + .stream_prompt(prompt) + .with_history(history) + .multi_turn(self.multi_turn_depth) + .await; + + let mut acc = String::new(); + + loop { + let Some(chunk) = response_stream.next().await else { + break Ok(acc); + }; + + match chunk { + Ok(MultiTurnStreamItem::StreamItem(StreamedAssistantContent::Text(Text { + text, + }))) => { + print!("{}", text); + 
acc.push_str(&text); + } + Ok(MultiTurnStreamItem::FinalResponse(final_response)) => { + self.usage = final_response.usage(); + } + Err(e) => { + break Err(PromptError::CompletionError( + CompletionError::ResponseError(e.to_string()), + )); + } + _ => continue, + } } } + + fn show_usage(&self) -> bool { + self.show_usage + } + + fn usage(&self) -> Option { + Some(self.usage) + } } -impl ChatbotBuilder { - /// Sets the `show_usage` flag, so that after a request the number of tokens - /// in the input and output will be printed - pub fn show_usage(self) -> Self { - Self { - show_usage: true, - ..self - } +impl Default for ChatBotBuilder { + fn default() -> Self { + Self(NoImplProvided) } +} - /// Sets the maximum depth for multi-turn, i.e. toolcalls - pub fn multi_turn_depth(self, multi_turn_depth: usize) -> Self { - Self { - multi_turn_depth, - ..self - } +impl ChatBotBuilder { + pub fn new() -> Self { + Self::default() + } + + pub fn agent( + self, + agent: Agent, + ) -> ChatBotBuilder> { + ChatBotBuilder(AgentImpl { + agent, + multi_turn_depth: 1, + show_usage: false, + usage: Usage::default(), + }) + } + + pub fn chat(self, chatbot: T) -> ChatBotBuilder> { + ChatBotBuilder(ChatImpl(chatbot)) } } -impl ChatbotBuilder> +impl ChatBotBuilder> where - M: CompletionModel + 'static, + T: Chat, { - /// Consumes the `ChatbotBuilder`, returning a `Chatbot` which can be run - pub fn build(self) -> Chatbot { - Chatbot { - agent: self.agent, - multi_turn_depth: self.multi_turn_depth, - show_usage: self.show_usage, - } + pub fn build(self) -> ChatBot> { + let ChatBotBuilder(chat_impl) = self; + ChatBot(chat_impl) } } -/// A CLI chatbot. -/// Only takes [Agent] types unlike [cli_chatbot] which takes any `impl Chat` type. 
-/// -/// # Example -/// ```rust -/// let chatbot = ChatbotBuilder::new().agent(my_agent).show_usage().build(); -/// -/// chatbot.run().await?; -pub struct Chatbot +impl ChatBotBuilder> where M: CompletionModel + 'static, { - agent: Agent, - multi_turn_depth: usize, - show_usage: bool, + pub fn multi_turn_depth(self, multi_turn_depth: usize) -> Self { + ChatBotBuilder(AgentImpl { + multi_turn_depth, + ..self.0 + }) + } + + pub fn show_usage(self) -> Self { + ChatBotBuilder(AgentImpl { + show_usage: true, + ..self.0 + }) + } + + pub fn build(self) -> ChatBot> { + ChatBot(self.0) + } } -impl Chatbot +#[allow(private_bounds)] +impl ChatBot where - M: CompletionModel + 'static, + T: CliChat, { - pub async fn run(self) -> Result<(), PromptError> { + pub async fn run(mut self) -> Result<(), PromptError> { let stdin = io::stdin(); let mut stdout = io::stdout(); - let mut chat_log = vec![]; - - println!("Welcome to the chatbot! Type 'exit' to quit."); + let mut history = vec![]; loop { print!("> "); - // Flush stdout to ensure the prompt appears before input stdout.flush().unwrap(); let mut input = String::new(); match stdin.read_line(&mut input) { Ok(_) => { - // Remove the newline character from the input let input = input.trim(); - - if input.is_empty() { - continue; - } - - // Check for a command to exit if input == "exit" { break; } - tracing::info!("Prompt:\n{}\n", input); - - let mut usage = None; - let mut response = String::new(); + tracing::info!("Prompt:\n{input}\n"); println!(); println!("========================== Response ============================"); - let mut stream_response = self - .agent - .stream_prompt(input) - .with_history(chat_log.clone()) - .multi_turn(self.multi_turn_depth) - .await; - - while let Some(chunk) = stream_response.next().await { - match chunk { - Ok(MultiTurnStreamItem::StreamItem( - StreamedAssistantContent::Text(Text { text }), - )) => { - print!("{text}"); - response.push_str(&text); - } - 
Ok(MultiTurnStreamItem::FinalResponse(r)) => { - if self.show_usage { - usage = Some(r.usage()); - } - } - Err(e) => { - return Err(PromptError::CompletionError( - CompletionError::ResponseError(e.to_string()), - )); - } - _ => {} - } - } + let response = self.0.request(input, history.clone()).await?; + history.push(Message::user(input)); + history.push(Message::assistant(response)); println!("================================================================"); println!(); - // `with_history` does not push to history, we have handle that - chat_log.push(Message::user(input)); - chat_log.push(Message::assistant(response.clone())); - - if let Some(usage) = usage { - println!( - "Input: {} tokens\nOutput: {} tokens", - usage.input_tokens, usage.output_tokens - ) + if self.0.show_usage() { + let Usage { + input_tokens, + output_tokens, + .. + } = self.0.usage().unwrap(); + println!("Input {input_tokens} tokens\nOutput {output_tokens} tokens"); } - - tracing::info!("Response:\n{}\n", response); } - Err(error) => println!("Error reading input: {error}"), + Err(e) => println!("Error reading request: {e}"), } } Ok(()) } } - -/// Utility function to create a simple REPL CLI chatbot from a type that implements the -/// `Chat` trait. -/// -/// Where the [Chatbot] type takes an agent, this takes any type that implements the [Chat] trait. -pub async fn cli_chatbot(chatbot: impl Chat) -> Result<(), PromptError> { - let stdin = io::stdin(); - let mut stdout = io::stdout(); - let mut chat_log = vec![]; - - println!("Welcome to the chatbot! 
Type 'exit' to quit."); - loop { - print!("> "); - // Flush stdout to ensure the prompt appears before input - stdout.flush().unwrap(); - - let mut input = String::new(); - match stdin.read_line(&mut input) { - Ok(_) => { - // Remove the newline character from the input - let input = input.trim(); - // Check for a command to exit - if input == "exit" { - break; - } - tracing::info!("Prompt:\n{}\n", input); - - let response = chatbot.chat(input, chat_log.clone()).await?; - chat_log.push(Message::user(input)); - chat_log.push(Message::assistant(response.clone())); - - println!("========================== Response ============================"); - println!("{response}"); - println!("================================================================\n\n"); - - tracing::info!("Response:\n{}\n", response); - } - Err(error) => println!("Error reading input: {error}"), - } - } - - Ok(()) -} diff --git a/rig-core/src/client/builder.rs b/rig-core/src/client/builder.rs index f093d1caa..5c30ac641 100644 --- a/rig-core/src/client/builder.rs +++ b/rig-core/src/client/builder.rs @@ -1,10 +1,12 @@ use crate::agent::Agent; use crate::client::ProviderClient; +use crate::completion::{CompletionRequest, Message}; use crate::embeddings::embedding::EmbeddingModelDyn; use crate::providers::{ anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic, mira, moonshot, ollama, openai, openrouter, perplexity, together, xai, }; +use crate::streaming::StreamingCompletionResponse; use crate::transcription::TranscriptionModelDyn; use rig::completion::CompletionModelDyn; use std::collections::HashMap; @@ -373,6 +375,124 @@ impl<'a> DynClientBuilder { model, }) } + + /// Stream a completion request to the specified provider and model. 
+ /// + /// # Arguments + /// * `provider` - The name of the provider (e.g., "openai", "anthropic") + /// * `model` - The name of the model (e.g., "gpt-4o", "claude-3-sonnet") + /// * `request` - The completion request containing prompt, parameters, etc. + /// + /// # Returns + /// A future that resolves to a streaming completion response + pub async fn stream_completion( + &self, + provider: &str, + model: &str, + request: CompletionRequest, + ) -> Result, ClientBuildError> { + let client = self.build(provider)?; + let completion = client + .as_completion() + .ok_or(ClientBuildError::UnsupportedFeature( + provider.to_string(), + "completion".to_string(), + ))?; + + let model = completion.completion_model(model); + model + .stream(request) + .await + .map_err(|e| ClientBuildError::FactoryError(e.to_string())) + } + + /// Stream a simple prompt to the specified provider and model. + /// + /// # Arguments + /// * `provider` - The name of the provider (e.g., "openai", "anthropic") + /// * `model` - The name of the model (e.g., "gpt-4o", "claude-3-sonnet") + /// * `prompt` - The prompt to send to the model + /// + /// # Returns + /// A future that resolves to a streaming completion response + pub async fn stream_prompt( + &self, + provider: &str, + model: &str, + prompt: impl Into + Send, + ) -> Result, ClientBuildError> { + let client = self.build(provider)?; + let completion = client + .as_completion() + .ok_or(ClientBuildError::UnsupportedFeature( + provider.to_string(), + "completion".to_string(), + ))?; + + let model = completion.completion_model(model); + let request = CompletionRequest { + preamble: None, + tools: vec![], + documents: vec![], + temperature: None, + max_tokens: None, + additional_params: None, + tool_choice: None, + chat_history: crate::OneOrMany::one(prompt.into()), + }; + + model + .stream(request) + .await + .map_err(|e| ClientBuildError::FactoryError(e.to_string())) + } + + /// Stream a chat with history to the specified provider and model. 
+ /// + /// # Arguments + /// * `provider` - The name of the provider (e.g., "openai", "anthropic") + /// * `model` - The name of the model (e.g., "gpt-4o", "claude-3-sonnet") + /// * `prompt` - The new prompt to send to the model + /// * `chat_history` - The chat history to include with the request + /// + /// # Returns + /// A future that resolves to a streaming completion response + pub async fn stream_chat( + &self, + provider: &str, + model: &str, + prompt: impl Into + Send, + chat_history: Vec, + ) -> Result, ClientBuildError> { + let client = self.build(provider)?; + let completion = client + .as_completion() + .ok_or(ClientBuildError::UnsupportedFeature( + provider.to_string(), + "completion".to_string(), + ))?; + + let model = completion.completion_model(model); + let mut history = chat_history; + history.push(prompt.into()); + + let request = CompletionRequest { + preamble: None, + tools: vec![], + documents: vec![], + temperature: None, + max_tokens: None, + additional_params: None, + tool_choice: None, + chat_history: crate::OneOrMany::many(history) + .unwrap_or_else(|_| crate::OneOrMany::one(Message::user(""))), + }; + + model + .stream(request) + .await + .map_err(|e| ClientBuildError::FactoryError(e.to_string())) + } } pub struct ProviderModelId<'builder, 'id> { @@ -397,6 +517,56 @@ impl<'builder> ProviderModelId<'builder, '_> { pub fn transcription(self) -> Result, ClientBuildError> { self.builder.transcription(self.provider, self.model) } + + /// Stream a completion request using this provider and model. + /// + /// # Arguments + /// * `request` - The completion request containing prompt, parameters, etc. + /// + /// # Returns + /// A future that resolves to a streaming completion response + pub async fn stream_completion( + self, + request: CompletionRequest, + ) -> Result, ClientBuildError> { + self.builder + .stream_completion(self.provider, self.model, request) + .await + } + + /// Stream a simple prompt using this provider and model. 
+ /// + /// # Arguments + /// * `prompt` - The prompt to send to the model + /// + /// # Returns + /// A future that resolves to a streaming completion response + pub async fn stream_prompt( + self, + prompt: impl Into + Send, + ) -> Result, ClientBuildError> { + self.builder + .stream_prompt(self.provider, self.model, prompt) + .await + } + + /// Stream a chat with history using this provider and model. + /// + /// # Arguments + /// * `prompt` - The new prompt to send to the model + /// * `chat_history` - The chat history to include with the request + /// + /// # Returns + /// A future that resolves to a streaming completion response + pub async fn stream_chat( + self, + prompt: impl Into + Send, + chat_history: Vec, + ) -> Result, ClientBuildError> { + self.builder + .stream_chat(self.provider, self.model, prompt, chat_history) + .await + } } #[cfg(feature = "image")] diff --git a/rig-core/src/client/mod.rs b/rig-core/src/client/mod.rs index 364433004..3bd6d1685 100644 --- a/rig-core/src/client/mod.rs +++ b/rig-core/src/client/mod.rs @@ -615,6 +615,7 @@ mod tests { temperature: None, max_tokens: None, additional_params: None, + tool_choice: None, chat_history: OneOrMany::one(Message::user("What is the capital of France?")), }); diff --git a/rig-core/src/completion/message.rs b/rig-core/src/completion/message.rs index 7bfd4f2cc..a0fe667e1 100644 --- a/rig-core/src/completion/message.rs +++ b/rig-core/src/completion/message.rs @@ -200,6 +200,10 @@ pub enum DocumentSourceKind { Url(String), /// A base-64 encoded string. Base64(String), + /// Raw bytes + Raw(Vec), + /// A string (or a string literal). + String(String), #[default] /// An unknown file source (there's nothing there). 
Unknown, @@ -214,6 +218,14 @@ impl DocumentSourceKind { Self::Base64(base64_string.to_string()) } + pub fn raw(bytes: impl Into>) -> Self { + Self::Raw(bytes.into()) + } + + pub fn string(input: &str) -> Self { + Self::String(input.into()) + } + pub fn unknown() -> Self { Self::Unknown } @@ -231,6 +243,8 @@ impl std::fmt::Display for DocumentSourceKind { match self { Self::Url(string) => write!(f, "{string}"), Self::Base64(string) => write!(f, "{string}"), + Self::String(string) => write!(f, "{string}"), + Self::Raw(_) => write!(f, ""), Self::Unknown => write!(f, ""), } } @@ -273,6 +287,7 @@ pub enum ContentFormat { #[default] Base64, String, + Url, } /// Helper enum that tracks the media type of the content. @@ -445,6 +460,20 @@ impl UserContent { }) } + /// Helper constructor to make creating user image content from raw unencoded bytes easier. + pub fn image_raw( + data: impl Into>, + media_type: Option, + detail: Option, + ) -> Self { + UserContent::Image(Image { + data: DocumentSourceKind::Raw(data.into()), + media_type, + detail, + ..Default::default() + }) + } + /// Helper constructor to make creating user image content easier. pub fn image_url( url: impl Into, @@ -468,7 +497,16 @@ impl UserContent { }) } - /// Helper to create an audio resource froma URL + /// Helper constructor to make creating user audio content from raw unencoded bytes easier. + pub fn audio_raw(data: impl Into>, media_type: Option) -> Self { + UserContent::Audio(Audio { + data: DocumentSourceKind::Raw(data.into()), + media_type, + ..Default::default() + }) + } + + /// Helper to create an audio resource from a URL pub fn audio_url(url: impl Into, media_type: Option) -> Self { UserContent::Audio(Audio { data: DocumentSourceKind::Url(url.into()), @@ -478,14 +516,26 @@ impl UserContent { } /// Helper constructor to make creating user document content easier. + /// This creates a document that assumes the data being passed in is a raw string. 
pub fn document(data: impl Into, media_type: Option) -> Self { + let data: String = data.into(); UserContent::Document(Document { - data: DocumentSourceKind::Base64(data.into()), + data: DocumentSourceKind::string(&data), media_type, additional_params: None, }) } + /// Helper to create a document from raw unencoded bytes + pub fn document_raw(data: impl Into>, media_type: Option) -> Self { + UserContent::Document(Document { + data: DocumentSourceKind::Raw(data.into()), + media_type, + ..Default::default() + }) + } + + /// Helper to create a document from a URL pub fn document_url(url: impl Into, media_type: Option) -> Self { UserContent::Document(Document { data: DocumentSourceKind::Url(url.into()), @@ -576,6 +626,20 @@ impl ToolResultContent { }) } + /// Helper constructor to make tool result images from raw unencoded bytes. + pub fn image_raw( + data: impl Into>, + media_type: Option, + detail: Option, + ) -> Self { + ToolResultContent::Image(Image { + data: DocumentSourceKind::Raw(data.into()), + media_type, + detail, + ..Default::default() + }) + } + + /// Helper constructor to make tool result images from a URL.
pub fn image_url( url: impl Into, @@ -902,6 +966,18 @@ impl From for Message { } } +#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum ToolChoice { + #[default] + Auto, + None, + Required, + Specific { + function_names: Vec, + }, +} + // ================================================================ // Error types // ================================================================ diff --git a/rig-core/src/completion/request.rs b/rig-core/src/completion/request.rs index 52d79ba1c..241fe591f 100644 --- a/rig-core/src/completion/request.rs +++ b/rig-core/src/completion/request.rs @@ -65,6 +65,7 @@ use super::message::{AssistantContent, DocumentMediaType}; use crate::client::completion::CompletionModelHandle; +use crate::message::ToolChoice; use crate::streaming::StreamingCompletionResponse; use crate::{OneOrMany, http_client, streaming}; use crate::{ @@ -249,6 +250,19 @@ impl GetTokenUsage for () { } } +impl GetTokenUsage for Option +where + T: GetTokenUsage, +{ + fn token_usage(&self) -> Option { + if let Some(usage) = self { + usage.token_usage() + } else { + None + } + } +} + /// Struct representing the token usage for a completion request. /// If tokens used are `0`, then the provider failed to supply token usage metrics. #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] @@ -416,6 +430,8 @@ pub struct CompletionRequest { pub temperature: Option, /// The max tokens to be sent to the completion model provider pub max_tokens: Option, + /// Whether tools are required to be used by the model provider or not before providing a response. 
+ pub tool_choice: Option, /// Additional provider-specific parameters to be sent to the completion model provider pub additional_params: Option, } @@ -503,6 +519,7 @@ pub struct CompletionRequestBuilder { tools: Vec, temperature: Option, max_tokens: Option, + tool_choice: Option, additional_params: Option, } @@ -517,6 +534,7 @@ impl CompletionRequestBuilder { tools: Vec::new(), temperature: None, max_tokens: None, + tool_choice: None, additional_params: None, } } @@ -624,6 +642,12 @@ impl CompletionRequestBuilder { self } + /// Sets the tool choice, constraining whether (and which) tools the model provider may use when responding. + pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self { + self.tool_choice = Some(tool_choice); + self + } + /// Builds the completion request. pub fn build(self) -> CompletionRequest { let chat_history = OneOrMany::many([self.chat_history, vec![self.prompt]].concat()) @@ -636,6 +660,7 @@ impl CompletionRequestBuilder { tools: self.tools, temperature: self.temperature, max_tokens: self.max_tokens, + tool_choice: self.tool_choice, additional_params: self.additional_params, } } @@ -718,6 +743,7 @@ mod tests { tools: Vec::new(), temperature: None, max_tokens: None, + tool_choice: None, additional_params: None, }; @@ -747,6 +773,7 @@ mod tests { tools: Vec::new(), temperature: None, max_tokens: None, + tool_choice: None, additional_params: None, }; diff --git a/rig-core/src/extractor.rs b/rig-core/src/extractor.rs index d4fbb1c47..7e4e86c77 100644 --- a/rig-core/src/extractor.rs +++ b/rig-core/src/extractor.rs @@ -37,7 +37,7 @@ use serde_json::json; use crate::{ agent::{Agent, AgentBuilder}, completion::{Completion, CompletionError, CompletionModel, ToolDefinition}, - message::{AssistantContent, Message, ToolCall, ToolFunction}, + message::{AssistantContent, Message, ToolCall, ToolChoice, ToolFunction}, tool::Tool, }; @@ -87,7 +87,7 @@ where retries = self.retries - i ); let attempt_text = text_message.clone(); - match self.extract_json(attempt_text).await { + match self.extract_json(attempt_text, vec![]).await {
Ok(data) => return Ok(data), Err(e) => { tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying..."); @@ -100,8 +100,45 @@ where Err(last_error.unwrap_or(ExtractionError::NoData)) } - async fn extract_json(&self, text: impl Into + Send) -> Result { - let response = self.agent.completion(text, vec![]).await?.send().await?; + /// Attempts to extract data from the given text with a number of retries. + /// + /// The function will retry the extraction if the initial attempt fails or + /// if the model does not call the `submit` tool. + /// + /// The number of retries is determined by the `retries` field on the Extractor struct. + pub async fn extract_with_chat_history( + &self, + text: impl Into + Send, + chat_history: Vec, + ) -> Result { + let mut last_error = None; + let text_message = text.into(); + + for i in 0..=self.retries { + tracing::debug!( + "Attempting to extract JSON. Retries left: {retries}", + retries = self.retries - i + ); + let attempt_text = text_message.clone(); + match self.extract_json(attempt_text, chat_history.clone()).await { + Ok(data) => return Ok(data), + Err(e) => { + tracing::warn!("Attempt {i} to extract JSON failed: {e:?}. Retrying..."); + last_error = Some(e); + } + } + } + + // If the loop finishes without a successful extraction, return the last error encountered. + Err(last_error.unwrap_or(ExtractionError::NoData)) + } + + async fn extract_json( + &self, + text: impl Into + Send, + messages: Vec, + ) -> Result { + let response = self.agent.completion(text, messages).await?.send().await?; if !response.choice.iter().any(|x| { let AssistantContent::ToolCall(ToolCall { @@ -189,7 +226,8 @@ where Use the `submit` function to submit the structured data.\n\ Be sure to fill out every field and ALWAYS CALL THE `submit` function, even with default values!!!. 
") - .tool(SubmitTool:: {_t: PhantomData}), + .tool(SubmitTool:: {_t: PhantomData}) + .tool_choice(ToolChoice::Required), retries: None, _t: PhantomData, } @@ -226,6 +264,12 @@ where self } + /// Set the `tool_choice` option for the inner Agent. + pub fn tool_choice(mut self, choice: ToolChoice) -> Self { + self.agent_builder = self.agent_builder.tool_choice(choice); + self + } + /// Build the Extractor pub fn build(self) -> Extractor { Extractor { diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index c05eeb45b..47874a8a7 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -139,3 +139,5 @@ pub use one_or_many::{EmptyListError, OneOrMany}; #[cfg(feature = "derive")] #[cfg_attr(docsrs, doc(cfg(feature = "derive")))] pub use rig_derive::Embed; + +pub mod telemetry; diff --git a/rig-core/src/providers/anthropic/completion.rs b/rig-core/src/providers/anthropic/completion.rs index da68995d8..905c06342 100644 --- a/rig-core/src/providers/anthropic/completion.rs +++ b/rig-core/src/providers/anthropic/completion.rs @@ -2,11 +2,12 @@ use crate::{ OneOrMany, - completion::{self, CompletionError}, + completion::{self, CompletionError, GetTokenUsage}, http_client::HttpClientExt, json_utils, message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning}, one_or_many::string_or_one_or_many, + telemetry::{ProviderResponseExt, SpanCombinator}, }; use std::{convert::Infallible, str::FromStr}; @@ -16,6 +17,7 @@ use crate::providers::anthropic::streaming::StreamingCompletionResponse; use bytes::Bytes; use serde::{Deserialize, Serialize}; use serde_json::json; +use tracing::{Instrument, info_span}; // ================================================================ // Anthropic Completion API @@ -60,7 +62,45 @@ pub struct CompletionResponse { pub usage: Usage, } -#[derive(Debug, Deserialize, Serialize)] +impl ProviderResponseExt for CompletionResponse { + type OutputMessage = Content; + type Usage = Usage; + + fn get_response_id(&self) -> Option { + 
Some(self.id.to_owned()) + } + + fn get_response_model_name(&self) -> Option { + Some(self.model.to_owned()) + } + + fn get_output_messages(&self) -> Vec { + self.content.clone() + } + + fn get_text_response(&self) -> Option { + let res = self + .content + .iter() + .filter_map(|x| { + if let Content::Text { text } = x { + Some(text.to_owned()) + } else { + None + } + }) + .collect::>() + .join("\n"); + + if res.is_empty() { None } else { Some(res) } + } + + fn get_usage(&self) -> Option { + Some(self.usage.clone()) + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct Usage { pub input_tokens: u64, pub cache_read_input_tokens: Option, @@ -87,6 +127,20 @@ impl std::fmt::Display for Usage { } } +impl GetTokenUsage for Usage { + fn token_usage(&self) -> Option { + let mut usage = crate::completion::Usage::new(); + + usage.input_tokens = self.input_tokens + + self.cache_creation_input_tokens.unwrap_or_default() + + self.cache_read_input_tokens.unwrap_or_default(); + usage.output_tokens = self.output_tokens; + usage.total_tokens = usage.input_tokens + usage.output_tokens; + + Some(usage) + } +} + #[derive(Debug, Deserialize, Serialize)] pub struct ToolDefinition { pub name: String, @@ -307,7 +361,10 @@ impl TryFrom for SourceType { fn try_from(format: message::ContentFormat) -> Result { match format { message::ContentFormat::Base64 => Ok(SourceType::BASE64), - message::ContentFormat::String => Ok(SourceType::URL), + message::ContentFormat::Url => Ok(SourceType::URL), + message::ContentFormat::String => Err(MessageError::ConversionError( + "ContentFormat::String is deprecated, use ContentFormat::Url for URLs".into(), + )), } } } @@ -316,7 +373,7 @@ impl From for message::ContentFormat { fn from(source_type: SourceType) -> Self { match source_type { SourceType::BASE64 => message::ContentFormat::Base64, - SourceType::URL => message::ContentFormat::String, + SourceType::URL => message::ContentFormat::Url, } } } @@ -446,6 +503,11 @@ impl TryFrom for Message { 
"Image content has no body".into(), )); } + doc => { + return Err(MessageError::ConversionError(format!( + "Unsupported document type: {doc:?}" + ))); + } }; Ok(Content::Image { source }) @@ -457,10 +519,15 @@ impl TryFrom for Message { "Document media type is required".to_string(), ))?; - let DocumentSourceKind::Base64(data) = data else { - return Err(MessageError::ConversionError( - "Only base64 encoded documents currently supported".into(), - )); + let data = match data { + DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => { + data + } + _ => { + return Err(MessageError::ConversionError( + "Only base64 encoded documents currently supported".into(), + )); + } }; let source = DocumentSource { @@ -630,11 +697,35 @@ pub enum ToolChoice { #[default] Auto, Any, + None, Tool { name: String, }, } +impl TryFrom for ToolChoice { + type Error = CompletionError; + fn try_from(value: message::ToolChoice) -> Result { + let res = match value { + message::ToolChoice::Auto => Self::Auto, + message::ToolChoice::None => Self::None, + message::ToolChoice::Required => Self::Any, + message::ToolChoice::Specific { function_names } => { + if function_names.len() != 1 { + return Err(CompletionError::ProviderError( + "Only one tool may be specified to be used by Claude".into(), + )); + } + + Self::Tool { + name: function_names.first().unwrap().to_string(), + } + } + }; + + Ok(res) + } +} impl completion::CompletionModel for CompletionModel where T: HttpClientExt + Clone + Default, @@ -647,6 +738,24 @@ where &self, completion_request: completion::CompletionRequest, ) -> Result, CompletionError> { + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "anthropic", + gen_ai.request.model = self.model, + gen_ai.system_instructions = &completion_request.preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + 
gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = tracing::field::Empty, + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; // Note: Ideally we'd introduce provider-specific Request models to handle the // specific requirements of each provider. For now, we just manually check while // building the request as a raw JSON document. @@ -667,6 +776,7 @@ where full_history.push(docs); } full_history.extend(completion_request.chat_history); + span.record_model_input(&full_history); let full_history = full_history .into_iter() @@ -684,6 +794,12 @@ where json_utils::merge_inplace(&mut request, json!({ "temperature": temperature })); } + let tool_choice = if let Some(tool_choice) = completion_request.tool_choice { + Some(ToolChoice::try_from(tool_choice)?) + } else { + None + }; + if !completion_request.tools.is_empty() { json_utils::merge_inplace( &mut request, @@ -697,7 +813,7 @@ where input_schema: tool.parameters, }) .collect::>(), - "tool_choice": ToolChoice::Auto, + "tool_choice": tool_choice, }), ); } @@ -706,55 +822,54 @@ where json_utils::merge_inplace(&mut request, params.clone()) } - tracing::debug!("Anthropic completion request: {request}"); - - let request: Vec = serde_json::to_vec(&request)?; - - let req = self - .client - .post("/v1/messages") - .body(request) - .map_err(|e| CompletionError::HttpError(e.into()))?; - - let response = self - .client - .send::<_, Bytes>(req) - .await - .map_err(CompletionError::HttpError)?; - - if response.status().is_success() { - match serde_json::from_slice::>( - response - .into_body() - .await - .map_err(CompletionError::HttpError)? - .to_vec() - .as_slice(), - )? 
{ - ApiResponse::Message(completion) => { - let completion: Result, _> = - completion.try_into(); - - tracing::info!( - target: "rig", - "Anthropic completion token usage: {:?}", - completion - ); - - completion + async move { + let request: Vec = serde_json::to_vec(&request)?; + + let req = self + .client + .post("/v1/messages") + .body(request) + .map_err(|e| CompletionError::HttpError(e.into()))?; + + let response = self + .client + .send::<_, Bytes>(req) + .await + .map_err(CompletionError::HttpError)?; + + if response.status().is_success() { + match serde_json::from_slice::>( + response + .into_body() + .await + .map_err(CompletionError::HttpError)? + .to_vec() + .as_slice(), + )? { + ApiResponse::Message(completion) => { + let span = tracing::Span::current(); + span.record_model_output(&completion.content); + span.record_response_metadata(&completion); + span.record_token_usage(&completion.usage); + completion.try_into() + } + ApiResponse::Error(ApiErrorResponse { message }) => { + Err(CompletionError::ResponseError(message)) + } } - ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)), + } else { + let text: String = String::from_utf8_lossy( + &response + .into_body() + .await + .map_err(CompletionError::HttpError)?, + ) + .into(); + Err(CompletionError::ProviderError(text)) } - } else { - let text: String = String::from_utf8_lossy( - &response - .into_body() - .await - .map_err(CompletionError::HttpError)?, - ) - .into(); - Err(CompletionError::ProviderError(text)) } + .instrument(span) + .await } #[cfg_attr(feature = "worker", worker::send)] @@ -1031,7 +1146,7 @@ mod tests { }) => { assert_eq!( data, - DocumentSourceKind::Base64("base64_encoded_pdf_data".into()) + DocumentSourceKind::String("base64_encoded_pdf_data".into()) ); assert_eq!(media_type, Some(message::DocumentMediaType::PDF)); } @@ -1086,4 +1201,30 @@ mod tests { assert_eq!(assistant_message, original_assistant_message); assert_eq!(tool_message, 
original_tool_message); } + + #[test] + fn test_content_format_conversion() { + use crate::completion::message::ContentFormat; + + let source_type: SourceType = ContentFormat::Url.try_into().unwrap(); + assert_eq!(source_type, SourceType::URL); + + let content_format: ContentFormat = SourceType::URL.into(); + assert_eq!(content_format, ContentFormat::Url); + + let source_type: SourceType = ContentFormat::Base64.try_into().unwrap(); + assert_eq!(source_type, SourceType::BASE64); + + let content_format: ContentFormat = SourceType::BASE64.into(); + assert_eq!(content_format, ContentFormat::Base64); + + let result: Result = ContentFormat::String.try_into(); + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("ContentFormat::String is deprecated") + ); + } } diff --git a/rig-core/src/providers/anthropic/streaming.rs b/rig-core/src/providers/anthropic/streaming.rs index da9559da7..5668c48ea 100644 --- a/rig-core/src/providers/anthropic/streaming.rs +++ b/rig-core/src/providers/anthropic/streaming.rs @@ -2,14 +2,17 @@ use async_stream::stream; use futures::StreamExt; use serde::{Deserialize, Serialize}; use serde_json::json; +use tracing::info_span; +use tracing_futures::Instrument; use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage}; use super::decoders::sse::from_response as sse_from_response; +use crate::OneOrMany; use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage}; use crate::http_client::{self, HttpClientExt}; use crate::json_utils::merge_inplace; -use crate::streaming; -use crate::streaming::{RawStreamingChoice, StreamingResult}; +use crate::streaming::{self, RawStreamingChoice, StreamingResult}; +use crate::telemetry::SpanCombinator; #[derive(Debug, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] @@ -69,6 +72,17 @@ pub struct PartialUsage { pub input_tokens: Option, } +impl GetTokenUsage for PartialUsage { + fn token_usage(&self) -> Option { + 
let mut usage = crate::completion::Usage::new(); + + usage.input_tokens = self.input_tokens.unwrap_or_default() as u64; + usage.output_tokens = self.output_tokens as u64; + usage.total_tokens = usage.input_tokens + usage.output_tokens; + Some(usage) + } +} + #[derive(Default)] struct ToolCallState { name: String, @@ -102,6 +116,24 @@ where completion_request: CompletionRequest, ) -> Result, CompletionError> { + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat_streaming", + gen_ai.operation.name = "chat_streaming", + gen_ai.provider.name = "anthropic", + gen_ai.request.model = self.model, + gen_ai.system_instructions = &completion_request.preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = self.model, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = tracing::field::Empty, + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; let max_tokens = if let Some(tokens) = completion_request.max_tokens { tokens } else if let Some(tokens) = self.default_max_tokens { @@ -117,6 +149,7 @@ where full_history.push(docs); } full_history.extend(completion_request.chat_history); + span.record_model_input(&full_history); let full_history = full_history .into_iter() @@ -192,6 +225,8 @@ where let mut sse_stream = Box::pin(stream); let mut input_tokens = 0; + let mut text_content = String::new(); + while let Some(sse_result) = sse_stream.next().await { match sse_result { Ok(sse) => { @@ -201,14 +236,27 @@ where match &event { StreamingEvent::MessageStart { message } => { input_tokens = message.usage.input_tokens; + + let span = tracing::Span::current(); + span.record("gen_ai.response.id", &message.id); + span.record("gen_ai.response.model", &message.model); }, StreamingEvent::MessageDelta { delta, usage } => { if delta.stop_reason.is_some() { + let usage = PartialUsage { +
output_tokens: usage.output_tokens, + input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")), + }; + + let span = tracing::Span::current(); + span.record_token_usage(&usage); + span.record_model_output(&Message { + role: super::completion::Role::Assistant, + content: OneOrMany::one(Content::Text { text: text_content.clone() })} + ); + yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { - usage: PartialUsage { - output_tokens: usage.output_tokens, - input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")), - } + usage })) } } @@ -216,6 +264,9 @@ where } if let Some(result) = handle_event(&event, &mut current_tool_call) { + if let Ok(RawStreamingChoice::Message(ref text)) = result { + text_content += text; + } yield result; } }, @@ -234,7 +285,7 @@ where } } } - }); + }.instrument(span)); Ok(streaming::StreamingCompletionResponse::stream(stream)) } diff --git a/rig-core/src/providers/azure.rs b/rig-core/src/providers/azure.rs index e648b3995..d0e1d74c1 100644 --- a/rig-core/src/providers/azure.rs +++ b/rig-core/src/providers/azure.rs @@ -11,6 +11,7 @@ use super::openai::{TranscriptionResponse, send_compatible_streaming_request}; +use crate::completion::GetTokenUsage; use crate::http_client::{self, HttpClientExt}; use crate::json_utils::merge; use crate::streaming::StreamingCompletionResponse; @@ -19,6 +20,7 @@ use crate::{ embeddings::{self, EmbeddingError}, json_utils, providers::openai, + telemetry::SpanCombinator, transcription::{self, TranscriptionError}, }; use bytes::Bytes; @@ -424,6 +426,18 @@ pub struct Usage { pub total_tokens: usize, } +impl GetTokenUsage for Usage { + fn token_usage(&self) -> Option { + let mut usage = crate::completion::Usage::new(); + + usage.input_tokens = self.prompt_tokens as u64; + usage.total_tokens = self.total_tokens as u64; + usage.output_tokens = usage.total_tokens - usage.input_tokens; + + Some(usage) + } +} + impl std::fmt::Display 
for Usage { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( @@ -619,41 +633,66 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "azure.openai", + gen_ai.request.model = self.model, + gen_ai.system_instructions = &completion_request.preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = tracing::field::Empty, + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; let request = self.create_completion_request(completion_request)?; + span.record_model_input( + &request + .get("messages") + .expect("Converting JSON should not fail"), + ); - let response = self - .client - .post_chat_completion(&self.model) - .json(&request) - .send() - .await - .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; - - if response.status().is_success() { - let t = response - .text() + async move { + let response = self + .client + .post_chat_completion(&self.model) + .json(&request) + .send() .await .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; - tracing::debug!(target: "rig", "Azure completion error: {}", t); - - match serde_json::from_str::>(&t)? 
{ - ApiResponse::Ok(response) => { - tracing::info!(target: "rig", - "Azure completion token usage: {:?}", - response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string()) - ); - response.try_into() + if response.status().is_success() { + let t = response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?; + tracing::debug!(target: "rig", "Azure completion error: {}", t); + + match serde_json::from_str::>(&t)? { + ApiResponse::Ok(response) => { + let span = tracing::Span::current(); + span.record_model_output(&response.choices); + span.record_response_metadata(&response); + span.record_token_usage(&response.usage); + response.try_into() + } + ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } - ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), + } else { + Err(CompletionError::ProviderError( + response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } - } else { - Err(CompletionError::ProviderError( - response.text().await.map_err(|e| { - CompletionError::HttpError(http_client::Error::Instance(e.into())) - })?, - )) } + .instrument(span) + .await } #[cfg_attr(feature = "worker", worker::send)] @@ -661,6 +700,7 @@ impl completion::CompletionModel for CompletionModel { &self, request: CompletionRequest, ) -> Result, CompletionError> { + let preamble = request.preamble.clone(); let mut request = self.create_completion_request(request)?; request = merge( @@ -673,7 +713,27 @@ impl completion::CompletionModel for CompletionModel { .post_chat_completion(self.model.as_str()) .json(&request); - send_compatible_streaming_request(builder).await + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat_streaming", + gen_ai.operation.name = "chat_streaming", + gen_ai.provider.name = "azure.openai", + gen_ai.request.model = self.model, + 
gen_ai.system_instructions = &preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + + tracing_futures::Instrument::instrument(send_compatible_streaming_request(builder), span) + .await } } @@ -767,6 +827,7 @@ impl transcription::TranscriptionModel for TranscriptionModel { // ================================================================ #[cfg(feature = "image")] pub use image_generation::*; +use tracing::{Instrument, info_span}; #[cfg(feature = "image")] #[cfg_attr(docsrs, doc(cfg(feature = "image")))] mod image_generation { @@ -974,6 +1035,7 @@ mod azure_tests { max_tokens: Some(100), temperature: Some(0.0), tools: vec![], + tool_choice: None, additional_params: None, }) .await diff --git a/rig-core/src/providers/cohere/completion.rs b/rig-core/src/providers/cohere/completion.rs index a7d3d299e..e9e2f19fb 100644 --- a/rig-core/src/providers/cohere/completion.rs +++ b/rig-core/src/providers/cohere/completion.rs @@ -1,9 +1,10 @@ use crate::{ OneOrMany, - completion::{self, CompletionError}, + completion::{self, CompletionError, GetTokenUsage}, http_client::{self, HttpClientExt}, json_utils, - message::{self, Reasoning}, + message::{self, Reasoning, ToolChoice}, + telemetry::SpanCombinator, }; use std::collections::HashMap; @@ -12,6 +13,7 @@ use crate::completion::CompletionRequest; use crate::providers::cohere::streaming::StreamingCompletionResponse; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; +use tracing::{Instrument, info_span}; #[derive(Debug, Deserialize, Serialize)] pub struct CompletionResponse { @@ -39,6 +41,47 @@ impl CompletionResponse { } } +impl crate::telemetry::ProviderResponseExt 
for CompletionResponse { + type OutputMessage = Message; + type Usage = Usage; + + fn get_response_id(&self) -> Option { + Some(self.id.clone()) + } + + fn get_response_model_name(&self) -> Option { + None + } + + fn get_output_messages(&self) -> Vec { + vec![self.message.clone()] + } + + fn get_text_response(&self) -> Option { + let Message::Assistant { ref content, .. } = self.message else { + return None; + }; + + let res = content + .iter() + .filter_map(|x| { + if let AssistantContent::Text { text } = x { + Some(text.to_string()) + } else { + None + } + }) + .collect::>() + .join("\n"); + + if res.is_empty() { None } else { Some(res) } + } + + fn get_usage(&self) -> Option { + self.usage.clone() + } +} + #[derive(Debug, Deserialize, PartialEq, Eq, Clone, Serialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum FinishReason { @@ -57,6 +100,20 @@ pub struct Usage { pub tokens: Option, } +impl GetTokenUsage for Usage { + fn token_usage(&self) -> Option { + let mut usage = crate::completion::Usage::new(); + + if let Some(ref billed_units) = self.billed_units { + usage.input_tokens = billed_units.input_tokens.unwrap_or_default() as u64; + usage.output_tokens = billed_units.output_tokens.unwrap_or_default() as u64; + usage.total_tokens = usage.input_tokens + usage.output_tokens; + } + + Some(usage) + } +} + #[derive(Debug, Deserialize, Clone, Serialize)] pub struct BilledUnits { #[serde(default)] @@ -507,6 +564,9 @@ where "documents": completion_request.documents, "temperature": completion_request.temperature, "tools": completion_request.tools.into_iter().map(Tool::from).collect::>(), + "tool_choice": if let Some(tool_choice) = completion_request.tool_choice && !matches!(tool_choice, ToolChoice::Auto) { tool_choice } else { + return Err(CompletionError::RequestError("\"auto\" is not an allowed tool_choice value in the Cohere API".into())) + }, }); if let Some(ref params) = completion_request.additional_params { @@ -527,39 +587,65 @@ impl 
completion::CompletionModel for CompletionModel { completion_request: completion::CompletionRequest, ) -> Result, CompletionError> { let request = self.create_completion_request(completion_request)?; + + let llm_span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "cohere", + gen_ai.request.model = self.model, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = self.model, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(request.get("messages").expect("Converting request messages to JSON should not fail!")).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + tracing::debug!( "Cohere request: {}", serde_json::to_string_pretty(&request)? ); - let response = self - .client - .client() - .post("/v2/chat") - .json(&request) - .send() - .await - .map_err(|e| http_client::Error::Instance(e.into()))?; - - if response.status().is_success() { - let text = response - .text() + async { + let response = self + .client + .client() + .post("/v2/chat") + .json(&request) + .send() .await - .map_err(|e| CompletionError::ResponseError(e.to_string()))?; - - tracing::debug!("Cohere response text: {}", text); + .map_err(|e| http_client::Error::Instance(e.into()))?; - let json_response: CompletionResponse = serde_json::from_str(&text)?; - let completion: completion::CompletionResponse = - json_response.try_into()?; - Ok(completion) - } else { - let text = response - .text() - .await - .map_err(|e| CompletionError::ResponseError(e.to_string()))?; - Err(CompletionError::ProviderError(text.to_string())) + if response.status().is_success() { + let text_response = response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?; + tracing::debug!("Cohere 
completion request: {}", text_response); + + let json_response: CompletionResponse = serde_json::from_str(&text_response)?; + let span = tracing::Span::current(); + span.record_token_usage(&json_response.usage); + span.record_model_output(&json_response.message); + span.record_response_metadata(&json_response); + tracing::debug!("Cohere completion response: {}", text_response); + let completion: completion::CompletionResponse = + json_response.try_into()?; + Ok(completion) + } else { + Err(CompletionError::ProviderError( + response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) + } } + .instrument(llm_span) + .await } #[cfg_attr(feature = "worker", worker::send)] diff --git a/rig-core/src/providers/cohere/streaming.rs b/rig-core/src/providers/cohere/streaming.rs index f15a310c3..f62ca4bba 100644 --- a/rig-core/src/providers/cohere/streaming.rs +++ b/rig-core/src/providers/cohere/streaming.rs @@ -1,12 +1,17 @@ use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage}; use crate::providers::cohere::CompletionModel; -use crate::providers::cohere::completion::Usage; +use crate::providers::cohere::completion::{ + AssistantContent, Message, ToolCall, ToolCallFunction, ToolType, Usage, +}; use crate::streaming::RawStreamingChoice; +use crate::telemetry::SpanCombinator; use crate::{json_utils, streaming}; use async_stream::stream; use futures::StreamExt; use reqwest_eventsource::Event; use serde::{Deserialize, Serialize}; +use tracing::info_span; +use tracing_futures::Instrument; #[derive(Debug, Deserialize)] #[serde(rename_all = "kebab-case", tag = "type")] @@ -91,10 +96,28 @@ impl CompletionModel { ) -> Result, CompletionError> { let request = self.create_completion_request(request)?; + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat_streaming", + gen_ai.operation.name = "chat_streaming", + gen_ai.provider.name = "cohere", + 
gen_ai.request.model = self.model, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = self.model, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + let request = json_utils::merge(request, serde_json::json!({"stream": true})); tracing::debug!( - "Cohere request: {}", + "Cohere streaming completion input: {}", serde_json::to_string_pretty(&request)? ); @@ -108,6 +131,8 @@ impl CompletionModel { let stream = Box::pin(stream! { let mut current_tool_call: Option<(String, String, String)> = None; + let mut text_response = String::new(); + let mut tool_calls = Vec::new(); while let Some(event_result) = event_source.next().await { match event_result { @@ -136,10 +161,23 @@ impl CompletionModel { let Some(content) = &message.content else { continue; }; let Some(text) = &content.text else { continue; }; + text_response += text; + yield Ok(RawStreamingChoice::Message(text.clone())); }, StreamingEvent::MessageEnd { delta: Some(delta) } => { + let message = Message::Assistant { + tool_calls: tool_calls.clone(), + content: vec![AssistantContent::Text { text: text_response.clone() }], + tool_plan: None, + citations: vec![] + }; + + let span = tracing::Span::current(); + span.record_token_usage(&delta.usage); + span.record_model_output(&vec![message]); + yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { usage: delta.usage.clone() })); @@ -168,7 +206,16 @@ impl CompletionModel { StreamingEvent::ToolCallEnd => { let Some(tc) = current_tool_call.clone() else { continue; }; - let Ok(args) = serde_json::from_str(&tc.2) else { continue; }; + let Ok(args) = serde_json::from_str::(&tc.2) else { continue; }; + + tool_calls.push(ToolCall { + id: Some(tc.0.clone()), + r#type: Some(ToolType::Function), + 
function: Some(ToolCallFunction { + name: tc.1.clone(), + arguments: args.clone() + }) + }); yield Ok(RawStreamingChoice::ToolCall { id: tc.0, @@ -195,7 +242,7 @@ impl CompletionModel { } event_source.close(); - }); + }.instrument(span)); Ok(streaming::StreamingCompletionResponse::stream(stream)) } diff --git a/rig-core/src/providers/deepseek.rs b/rig-core/src/providers/deepseek.rs index 780273df4..5bc45a74e 100644 --- a/rig-core/src/providers/deepseek.rs +++ b/rig-core/src/providers/deepseek.rs @@ -14,6 +14,7 @@ use bytes::Bytes; use futures::StreamExt; use reqwest_eventsource::{Event, RequestBuilderExt}; use std::collections::HashMap; +use tracing::{Instrument, info_span}; use crate::client::{CompletionClient, ProviderClient, VerifyClient, VerifyError}; use crate::completion::GetTokenUsage; @@ -400,7 +401,9 @@ impl TryFrom for Vec { name: None, }), message::UserContent::Document(Document { - data: DocumentSourceKind::Base64(content), + data: + DocumentSourceKind::Base64(content) + | DocumentSourceKind::String(content), .. }) => Some(Message::User { content, @@ -416,6 +419,22 @@ impl TryFrom for Vec { message::Message::Assistant { content, .. 
} => { let mut messages: Vec = vec![]; + // extract text + let text_content = content + .clone() + .into_iter() + .filter_map(|content| match content { + message::AssistantContent::Text(text) => Some(Message::Assistant { + content: text.text, + name: None, + tool_calls: vec![], + }), + _ => None, + }) + .collect::>(); + + messages.extend(text_content); + // extract tool calls let tool_calls = content .clone() @@ -437,21 +456,6 @@ impl TryFrom for Vec { }); } - // extract text - let text_content = content - .into_iter() - .filter_map(|content| match content { - message::AssistantContent::Text(text) => Some(Message::Assistant { - content: text.text, - name: None, - tool_calls: vec![], - }), - _ => None, - }) - .collect::>(); - - messages.extend(text_content); - Ok(messages) } } @@ -591,6 +595,11 @@ impl CompletionModel { .collect::>(), ); + let tool_choice = completion_request + .tool_choice + .map(crate::providers::openrouter::ToolChoice::try_from) + .transpose()?; + let request = if completion_request.tools.is_empty() { json!({ "model": self.model, @@ -603,7 +612,7 @@ impl CompletionModel { "messages": full_history, "temperature": completion_request.temperature, "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::>(), - "tool_choice": "auto", + "tool_choice": tool_choice, }) }; @@ -629,37 +638,74 @@ impl completion::CompletionModel for CompletionModel { completion::CompletionResponse, crate::completion::CompletionError, > { + let preamble = completion_request.preamble.clone(); let request = self.create_completion_request(completion_request)?; + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "deepseek", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = 
tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + tracing::debug!("DeepSeek completion request: {request:?}"); - let response = self - .client - .reqwest_post("/chat/completions") - .json(&request) - .send() - .await - .map_err(|e| http_client::Error::Instance(e.into()))?; - - if response.status().is_success() { - let t: String = response - .text() + async move { + let response = self + .client + .reqwest_post("/chat/completions") + .json(&request) + .send() .await .map_err(|e| http_client::Error::Instance(e.into()))?; - tracing::debug!(target: "rig", "DeepSeek completion: {}", t); - match serde_json::from_str::>(&t)? { - ApiResponse::Ok(response) => response.try_into(), - ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), - } - } else { - Err(CompletionError::ProviderError( - response + if response.status().is_success() { + let t = response .text() .await - .map_err(|e| http_client::Error::Instance(e.into()))?, - )) + .map_err(|e| http_client::Error::Instance(e.into()))?; + + tracing::debug!(target: "rig", "DeepSeek completion: {t}"); + + match serde_json::from_str::>(&t)? 
{ + ApiResponse::Ok(response) => { + let span = tracing::Span::current(); + span.record( + "gen_ai.output.messages", + serde_json::to_string(&response.choices).unwrap(), + ); + span.record("gen_ai.usage.input_tokens", response.usage.prompt_tokens); + span.record( + "gen_ai.usage.output_tokens", + response.usage.completion_tokens, + ); + response.try_into() + } + ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), + } + } else { + Err(CompletionError::ProviderError( + response + .text() + .await + .map_err(|e| http_client::Error::Instance(e.into()))?, + )) + } } + .instrument(span) + .await } #[cfg_attr(feature = "worker", worker::send)] @@ -670,6 +716,7 @@ impl completion::CompletionModel for CompletionModel { crate::streaming::StreamingCompletionResponse, CompletionError, > { + let preamble = completion_request.preamble.clone(); let mut request = self.create_completion_request(completion_request)?; request = merge( @@ -678,7 +725,27 @@ impl completion::CompletionModel for CompletionModel { ); let builder = self.client.reqwest_post("/chat/completions").json(&request); - send_compatible_streaming_request(builder).await + + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat_streaming", + gen_ai.operation.name = "chat_streaming", + gen_ai.provider.name = "deepseek", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + + tracing::Instrument::instrument(send_compatible_streaming_request(builder), span).await } } @@ -724,12 +791,14 @@ pub async fn send_compatible_streaming_request( 
crate::streaming::StreamingCompletionResponse, CompletionError, > { + let span = tracing::Span::current(); let mut event_source = request_builder .eventsource() .expect("Cloning request must succeed"); let stream = Box::pin(stream! { let mut final_usage = Usage::new(); + let mut text_response = String::new(); let mut calls: HashMap = HashMap::new(); while let Some(event_result) = event_source.next().await { @@ -806,6 +875,7 @@ pub async fn send_compatible_streaming_request( } if let Some(content) = &delta.content { + text_response += content; yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone())); } } @@ -825,11 +895,22 @@ pub async fn send_compatible_streaming_request( } } + let mut tool_calls = Vec::new(); // Flush accumulated tool calls - for (_, (id, name, arguments)) in calls { + for (index, (id, name, arguments)) in calls { let Ok(arguments_json) = serde_json::from_str::(&arguments) else { continue; }; + + tool_calls.push(ToolCall { + id: id.clone(), + index, + r#type: ToolType::Function, + function: Function { + name: name.clone(), + arguments: arguments_json.clone() + } + }); yield Ok(crate::streaming::RawStreamingChoice::ToolCall { id, name, @@ -838,6 +919,14 @@ pub async fn send_compatible_streaming_request( }); } + let message = Message::Assistant { + content: text_response, + name: None, + tool_calls + }; + + span.record("gen_ai.output.messages", serde_json::to_string(&message).unwrap()); + yield Ok(crate::streaming::RawStreamingChoice::FinalResponse( StreamingCompletionResponse { usage: final_usage.clone() } )); diff --git a/rig-core/src/providers/galadriel.rs b/rig-core/src/providers/galadriel.rs index b2b910db4..c29858970 100644 --- a/rig-core/src/providers/galadriel.rs +++ b/rig-core/src/providers/galadriel.rs @@ -25,6 +25,7 @@ use crate::{ use bytes::Bytes; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; +use tracing::{Instrument, info_span}; // ================================================================ 
// Main Galadriel Client @@ -549,6 +550,12 @@ where .collect::, _>>()?, ); + let tool_choice = completion_request + .tool_choice + .clone() + .map(crate::providers::openai::completion::ToolChoice::try_from) + .transpose()?; + let request = if completion_request.tools.is_empty() { json!({ "model": self.model, @@ -561,7 +568,7 @@ where "messages": full_history, "temperature": completion_request.temperature, "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::>(), - "tool_choice": "auto", + "tool_choice": tool_choice, }) }; @@ -584,8 +591,9 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { - let body = self.create_completion_request(completion_request)?; - let body = serde_json::to_vec(&body)?; + let preamble = completion_request.preamble.clone(); + let request = self.create_completion_request(completion_request)?; + let body = serde_json::to_vec(&request)?; let req = self .client @@ -593,27 +601,60 @@ impl completion::CompletionModel for CompletionModel { .body(body) .map_err(http_client::Error::from)?; - let response = self.client.send(req).await?; - - if response.status().is_success() { - let text = http_client::text(response).await?; - - tracing::debug!(target: "rig", "Galadriel completion error: {}", text); + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "galadriel", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; - 
match serde_json::from_str::>(&text)? { - ApiResponse::Ok(response) => { - tracing::info!(target: "rig", - "Galadriel completion token usage: {:?}", - response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string()) - ); - response.try_into() + async move { + let response = self.client.send(req).await?; + + if response.status().is_success() { + let t = http_client::text(response).await?; + tracing::debug!(target: "rig::completions", "Galadriel completion response: {t}"); + + match serde_json::from_str::>(&t)? { + ApiResponse::Ok(response) => { + let span = tracing::Span::current(); + span.record("gen_ai.response.id", response.id.clone()); + span.record("gen_ai.response.model_name", response.model.clone()); + span.record( + "gen_ai.output.messages", + serde_json::to_string(&response.choices).unwrap(), + ); + if let Some(ref usage) = response.usage { + span.record("gen_ai.usage.input_tokens", usage.prompt_tokens); + span.record( + "gen_ai.usage.output_tokens", + usage.total_tokens - usage.prompt_tokens, + ); + } + response.try_into() + } + ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } - ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), + } else { + let text = http_client::text(response).await?; + + Err(CompletionError::ProviderError(text)) } - } else { - let text = http_client::text(response).await?; - Err(CompletionError::ProviderError(text)) } + .instrument(span) + .await } #[cfg_attr(feature = "worker", worker::send)] @@ -621,6 +662,7 @@ impl completion::CompletionModel for CompletionModel { &self, request: CompletionRequest, ) -> Result, CompletionError> { + let preamble = request.preamble.clone(); let mut request = self.create_completion_request(request)?; request = merge( @@ -630,6 +672,27 @@ impl completion::CompletionModel for CompletionModel { let builder = self.client.reqwest_post("/chat/completions").json(&request); - send_compatible_streaming_request(builder).await + let span = if 
tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat_streaming", + gen_ai.operation.name = "chat_streaming", + gen_ai.provider.name = "galadriel", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + + send_compatible_streaming_request(builder) + .instrument(span) + .await } } diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index 70360f3b4..f703c98fd 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -29,8 +29,11 @@ pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro"; use self::gemini_api_types::Schema; use crate::message::Reasoning; -use crate::providers::gemini::completion::gemini_api_types::AdditionalParameters; +use crate::providers::gemini::completion::gemini_api_types::{ + AdditionalParameters, FunctionCallingMode, ToolConfig, +}; use crate::providers::gemini::streaming::StreamingCompletionResponse; +use crate::telemetry::SpanCombinator; use crate::{ OneOrMany, completion::{self, CompletionError, CompletionRequest}, @@ -41,6 +44,7 @@ use gemini_api_types::{ }; use serde_json::{Map, Value}; use std::convert::TryFrom; +use tracing::info_span; use super::Client; @@ -72,14 +76,36 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { - let body = create_request_body(completion_request) - .and_then(|body| serde_json::to_vec(&body).map_err(Into::into))?; + let span = if tracing::Span::current().is_disabled() { + info_span!( + 
target: "rig::completions", + "generate_content", + gen_ai.operation.name = "generate_content", + gen_ai.provider.name = "gcp.gemini", + gen_ai.request.model = self.model, + gen_ai.system_instructions = &completion_request.preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = tracing::field::Empty, + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + + let request = create_request_body(completion_request)?; + + span.record_model_input(&request.contents); tracing::debug!( "Sending completion request to Gemini API {}", - String::from_utf8_lossy(&body) + serde_json::to_string_pretty(&request)? ); + let body = serde_json::to_vec(&request)?; + let request = self .client .post(&format!("/v1beta/models/{}:generateContent", self.model)) @@ -94,9 +120,9 @@ impl completion::CompletionModel for CompletionModel { .await .map_err(CompletionError::HttpError)?; - let body: GenerateContentResponse = serde_json::from_slice(&response_body)?; + let response: GenerateContentResponse = serde_json::from_slice(&response_body)?; - match body.usage_metadata { + match response.usage_metadata { Some(ref usage) => tracing::info!(target: "rig", "Gemini completion token usage: {}", usage @@ -106,9 +132,17 @@ impl completion::CompletionModel for CompletionModel { ), } - tracing::debug!("Received response"); + let span = tracing::Span::current(); + span.record_model_output(&response.candidates); + span.record_response_metadata(&response); + span.record_token_usage(&response.usage_metadata); + + tracing::debug!( + "Received response from Gemini API: {}", + serde_json::to_string_pretty(&response)? + ); - Ok(completion::CompletionResponse::try_from(body)?) 
+ response.try_into() } else { let text = String::from_utf8_lossy( &response @@ -168,6 +202,14 @@ pub(crate) fn create_request_body( Some(Tool::try_from(completion_request.tools)?) }; + let tool_config = if let Some(cfg) = completion_request.tool_choice { + Some(ToolConfig { + function_calling_config: Some(FunctionCallingMode::try_from(cfg)?), + }) + } else { + None + }; + let request = GenerateContentRequest { contents: full_history .into_iter() @@ -179,7 +221,7 @@ pub(crate) fn create_request_body( generation_config: Some(generation_config), safety_settings: None, tools, - tool_config: None, + tool_config, system_instruction, additional_params, }; @@ -310,6 +352,7 @@ impl TryFrom for completion::CompletionResponse, /// Returns the prompt's feedback related to the content filters. @@ -367,8 +412,60 @@ pub mod gemini_api_types { pub model_version: Option, } + impl ProviderResponseExt for GenerateContentResponse { + type OutputMessage = ContentCandidate; + type Usage = UsageMetadata; + + fn get_response_id(&self) -> Option { + Some(self.response_id.clone()) + } + + fn get_response_model_name(&self) -> Option { + None + } + + fn get_output_messages(&self) -> Vec { + self.candidates.clone() + } + + fn get_text_response(&self) -> Option { + let str = self + .candidates + .iter() + .filter_map(|x| { + if x.content.role.as_ref().is_none_or(|y| y != &Role::Model) { + return None; + } + + let res = x + .content + .parts + .iter() + .filter_map(|part| { + if let PartKind::Text(ref str) = part.part { + Some(str.to_owned()) + } else { + None + } + }) + .collect::>() + .join("\n"); + + Some(res) + }) + .collect::>() + .join("\n"); + + if str.is_empty() { None } else { Some(str) } + } + + fn get_usage(&self) -> Option { + self.usage_metadata.clone() + } + } + /// A response candidate generated from the model. 
- #[derive(Debug, Deserialize, Serialize)] + #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct ContentCandidate { /// Output only. Generated content returned from the model. @@ -393,7 +490,7 @@ pub mod gemini_api_types { pub index: Option, } - #[derive(Debug, Deserialize, Serialize)] + #[derive(Clone, Debug, Deserialize, Serialize)] pub struct Content { /// Ordered Parts that constitute a single message. Parts may have different MIME types. #[serde(default)] @@ -600,8 +697,15 @@ pub mod gemini_api_types { mime_type: Some(mime_type), file_uri: url, }), - DocumentSourceKind::Base64(data) => PartKind::InlineData(Blob { mime_type, data }), - _ => { + DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => { + PartKind::InlineData(Blob { mime_type, data }) + } + DocumentSourceKind::Raw(_) => { + return Err(message::MessageError::ConversionError( + "Raw files not supported, encode as base64 first".into(), + )); + } + DocumentSourceKind::Unknown => { return Err(message::MessageError::ConversionError( "Can't convert an unknown document source".to_string(), )); @@ -679,9 +783,11 @@ pub mod gemini_api_types { message::UserContent::Document(message::Document { data, media_type, .. 
}) => { - let media_type = media_type.ok_or(message::MessageError::ConversionError( - "Media type for document is required for Gemini".to_string(), - ))?; + let Some(media_type) = media_type else { + return Err(MessageError::ConversionError( + "A mime type is required for document inputs to Gemini".to_string(), + )); + }; if !media_type.is_code() { let mime_type = media_type.to_mime_type().to_string(); @@ -691,8 +797,13 @@ pub mod gemini_api_types { mime_type: Some(mime_type), file_uri, }), - DocumentSourceKind::Base64(data) => { - PartKind::InlineData(Blob { data, mime_type }) + DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data) => { + PartKind::InlineData(Blob { mime_type, data }) + } + DocumentSourceKind::Raw(_) => { + return Err(message::MessageError::ConversionError( + "Raw files not supported, encode as base64 first".into(), + )); } _ => { return Err(message::MessageError::ConversionError( @@ -716,9 +827,11 @@ pub mod gemini_api_types { message::UserContent::Audio(message::Audio { data, media_type, .. 
}) => { - let media_type = media_type.ok_or(message::MessageError::ConversionError( - "Media type for audio is required for Gemini".to_string(), - ))?; + let Some(media_type) = media_type else { + return Err(MessageError::ConversionError( + "A mime type is required for audio inputs to Gemini".to_string(), + )); + }; let mime_type = media_type.to_mime_type().to_string(); @@ -726,11 +839,22 @@ pub mod gemini_api_types { DocumentSourceKind::Base64(data) => { PartKind::InlineData(Blob { data, mime_type }) } + DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData { mime_type: Some(mime_type), file_uri, }), - _ => { + DocumentSourceKind::String(_) => { + return Err(message::MessageError::ConversionError( + "Strings cannot be used as audio files!".into(), + )); + } + DocumentSourceKind::Raw(_) => { + return Err(message::MessageError::ConversionError( + "Raw files not supported, encode as base64 first".into(), + )); + } + DocumentSourceKind::Unknown => { return Err(message::MessageError::ConversionError( "Content has no body".to_string(), )); @@ -749,21 +873,49 @@ pub mod gemini_api_types { additional_params, .. 
}) => { - let media_type = media_type.ok_or(message::MessageError::ConversionError( - "Media type for video is required for Gemini".to_string(), - ))?; - - let mime_type = media_type.to_mime_type().to_owned(); + let mime_type = media_type.map(|media_ty| media_ty.to_mime_type().to_string()); let part = match data { - DocumentSourceKind::Url(file_uri) => PartKind::FileData(FileData { - mime_type: Some(mime_type), - file_uri, - }), + DocumentSourceKind::Url(file_uri) => { + if file_uri.starts_with("https://www.youtube.com") { + PartKind::FileData(FileData { + mime_type, + file_uri, + }) + } else { + if mime_type.is_none() { + return Err(MessageError::ConversionError( + "A mime type is required for non-Youtube video file inputs to Gemini" + .to_string(), + )); + } + + PartKind::FileData(FileData { + mime_type, + file_uri, + }) + } + } DocumentSourceKind::Base64(data) => { + let Some(mime_type) = mime_type else { + return Err(MessageError::ConversionError( + "A media type is expected for base64 encoded strings" + .to_string(), + )); + }; PartKind::InlineData(Blob { mime_type, data }) } - _ => { + DocumentSourceKind::String(_) => { + return Err(message::MessageError::ConversionError( + "Strings cannot be used as audio files!".into(), + )); + } + DocumentSourceKind::Raw(_) => { + return Err(message::MessageError::ConversionError( + "Raw file data not supported, encode as base64 first".into(), + )); + } + DocumentSourceKind::Unknown => { return Err(message::MessageError::ConversionError( "Media type for video is required for Gemini".to_string(), )); @@ -946,6 +1098,21 @@ pub mod gemini_api_types { } } + impl GetTokenUsage for UsageMetadata { + fn token_usage(&self) -> Option { + let mut usage = crate::completion::Usage::new(); + + usage.input_tokens = self.prompt_token_count as u64; + usage.output_tokens = (self.cached_content_token_count.unwrap_or_default() + + self.candidates_token_count.unwrap_or_default() + + self.thoughts_token_count.unwrap_or_default()) + as u64; + 
usage.total_tokens = usage.input_tokens + usage.output_tokens; + + Some(usage) + } + } + /// A set of the feedback metadata the prompt specified in [GenerateContentRequest.contents](GenerateContentRequest). #[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] @@ -972,7 +1139,7 @@ pub mod gemini_api_types { ProhibitedContent, } - #[derive(Debug, Deserialize, Serialize)] + #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum FinishReason { /// Default value. This value is unused. @@ -999,13 +1166,13 @@ pub mod gemini_api_types { MalformedFunctionCall, } - #[derive(Debug, Deserialize, Serialize)] + #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct CitationMetadata { pub citation_sources: Vec, } - #[derive(Debug, Deserialize, Serialize)] + #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct CitationSource { #[serde(skip_serializing_if = "Option::is_none")] @@ -1018,19 +1185,19 @@ pub mod gemini_api_types { pub license: Option, } - #[derive(Debug, Deserialize, Serialize)] + #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct LogprobsResult { pub top_candidate: Vec, pub chosen_candidate: Vec, } - #[derive(Debug, Deserialize, Serialize)] + #[derive(Clone, Debug, Deserialize, Serialize)] pub struct TopCandidate { pub candidates: Vec, } - #[derive(Debug, Deserialize, Serialize)] + #[derive(Clone, Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct LogProbCandidate { pub token: String, @@ -1381,14 +1548,43 @@ pub mod gemini_api_types { pub parameters: Option, } - #[derive(Debug, Serialize)] + #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct ToolConfig { - pub schema: Option, + pub function_calling_config: Option, + } + + #[derive(Debug, Serialize, Deserialize, Default)] + #[serde(tag = "mode", rename_all = 
"UPPERCASE")] + pub enum FunctionCallingMode { + #[default] + Auto, + None, + Any { + #[serde(skip_serializing_if = "Option::is_none")] + allowed_function_names: Option>, + }, + } + + impl TryFrom for FunctionCallingMode { + type Error = CompletionError; + fn try_from(value: message::ToolChoice) -> Result { + let res = match value { + message::ToolChoice::Auto => Self::Auto, + message::ToolChoice::None => Self::None, + message::ToolChoice::Required => Self::Any { + allowed_function_names: None, + }, + message::ToolChoice::Specific { function_names } => Self::Any { + allowed_function_names: Some(function_names), + }, + }; + + Ok(res) + } } #[derive(Debug, Serialize)] - #[serde(rename_all = "camelCase")] pub struct CodeExecution {} #[derive(Debug, Serialize)] diff --git a/rig-core/src/providers/gemini/streaming.rs b/rig-core/src/providers/gemini/streaming.rs index 395d96ffe..d0ba60433 100644 --- a/rig-core/src/providers/gemini/streaming.rs +++ b/rig-core/src/providers/gemini/streaming.rs @@ -1,7 +1,9 @@ +use crate::telemetry::SpanCombinator; use async_stream::stream; use futures::StreamExt; use reqwest_eventsource::{Event, RequestBuilderExt}; use serde::{Deserialize, Serialize}; +use tracing::info_span; use super::completion::{ CompletionModel, create_request_body, @@ -25,6 +27,20 @@ pub struct PartialUsage { pub prompt_token_count: i32, } +impl GetTokenUsage for PartialUsage { + fn token_usage(&self) -> Option { + let mut usage = crate::completion::Usage::new(); + + usage.input_tokens = self.prompt_token_count as u64; + usage.output_tokens = (self.cached_content_token_count.unwrap_or_default() + + self.candidates_token_count.unwrap_or_default() + + self.thoughts_token_count.unwrap_or_default()) as u64; + usage.total_tokens = usage.input_tokens + usage.output_tokens; + + Some(usage) + } +} + #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] pub struct StreamGenerateContentResponse { @@ -59,8 +75,28 @@ impl CompletionModel { completion_request: 
CompletionRequest, ) -> Result, CompletionError> { + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat_streaming", + gen_ai.operation.name = "chat_streaming", + gen_ai.provider.name = "gcp.gemini", + gen_ai.request.model = self.model, + gen_ai.system_instructions = &completion_request.preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = self.model, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = tracing::field::Empty, + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; let request = create_request_body(completion_request)?; + span.record_model_input(&request.contents); + tracing::debug!( "Sending completion request to Gemini API {}", serde_json::to_string_pretty(&request)? @@ -78,6 +114,8 @@ impl CompletionModel { .expect("Cloning request must always succeed"); let stream = Box::pin(stream! { + let mut text_response = String::new(); + let mut model_outputs: Vec = Vec::new(); while let Some(event_result) = event_source.next().await { match event_result { Ok(Event::Open) => { @@ -116,12 +154,14 @@ impl CompletionModel { part: PartKind::Text(text), .. }) => { + text_response += text; yield Ok(streaming::RawStreamingChoice::Message(text.clone())); }, Some(Part { part: PartKind::FunctionCall(function_call), .. 
}) => { + model_outputs.push(choice.content.parts.first().cloned().expect("This should never fail")); yield Ok(streaming::RawStreamingChoice::ToolCall { name: function_call.name.clone(), id: function_call.name.clone(), @@ -137,6 +177,12 @@ impl CompletionModel { // Check if this is the final response if choice.finish_reason.is_some() { + if !text_response.is_empty() { + model_outputs.push(Part { thought: None, thought_signature: None, part: PartKind::Text(text_response), additional_params: None }); + } + let span = tracing::Span::current(); + span.record_model_output(&model_outputs); + span.record_token_usage(&data.usage_metadata); yield Ok(streaming::RawStreamingChoice::FinalResponse(StreamingCompletionResponse { usage_metadata: data.usage_metadata.unwrap_or_default() })); diff --git a/rig-core/src/providers/groq.rs b/rig-core/src/providers/groq.rs index 93de12759..5e61ffe5a 100644 --- a/rig-core/src/providers/groq.rs +++ b/rig-core/src/providers/groq.rs @@ -10,12 +10,15 @@ //! ``` use reqwest_eventsource::{Event, RequestBuilderExt}; use std::collections::HashMap; +use tracing::info_span; +use tracing_futures::Instrument; use super::openai::{CompletionResponse, StreamingToolCall, TranscriptionResponse, Usage}; use crate::client::{CompletionClient, TranscriptionClient, VerifyClient, VerifyError}; use crate::completion::GetTokenUsage; use crate::http_client::{self, HttpClientExt}; use crate::json_utils::merge; +use crate::providers::openai::{AssistantContent, Function, ToolType}; use async_stream::stream; use futures::StreamExt; @@ -432,6 +435,11 @@ impl CompletionModel { .collect::, _>>()?, ); + let tool_choice = completion_request + .tool_choice + .map(crate::providers::openai::ToolChoice::try_from) + .transpose()?; + let request = if completion_request.tools.is_empty() { json!({ "model": self.model, @@ -444,7 +452,7 @@ impl CompletionModel { "messages": full_history, "temperature": completion_request.temperature, "tools": 
completion_request.tools.into_iter().map(ToolDefinition::from).collect::>(), - "tool_choice": "auto", + "tool_choice": tool_choice, "reasoning_format": "parsed" }) }; @@ -468,38 +476,73 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { - let request = self.create_completion_request(completion_request)?; + let preamble = completion_request.preamble.clone(); - let response = self - .client - .reqwest_post("/chat/completions") - .json(&request) - .send() - .await - .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; + let request = self.create_completion_request(completion_request)?; + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "groq", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; - if response.status().is_success() { - match response - .json::>() + let async_block = async move { + let response = self + .client + .reqwest_post("/chat/completions") + .json(&request) + .send() .await - .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))? 
- { - ApiResponse::Ok(response) => { - tracing::info!(target: "rig", - "groq completion token usage: {:?}", - response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string()) - ); - response.try_into() + .map_err(|e| http_client::Error::Instance(e.into()))?; + + if response.status().is_success() { + match response + .json::>() + .await + .map_err(|e| http_client::Error::Instance(e.into()))? + { + ApiResponse::Ok(response) => { + let span = tracing::Span::current(); + span.record("gen_ai.response.id", response.id.clone()); + span.record("gen_ai.response.model_name", response.model.clone()); + span.record( + "gen_ai.output.messages", + serde_json::to_string(&response.choices).unwrap(), + ); + if let Some(ref usage) = response.usage { + span.record("gen_ai.usage.input_tokens", usage.prompt_tokens); + span.record( + "gen_ai.usage.output_tokens", + usage.total_tokens - usage.prompt_tokens, + ); + } + response.try_into() + } + ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } - ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), + } else { + Err(CompletionError::ProviderError( + response + .text() + .await + .map_err(|e| http_client::Error::Instance(e.into()))?, + )) } - } else { - Err(CompletionError::ProviderError( - response.text().await.map_err(|e| { - CompletionError::HttpError(http_client::Error::Instance(e.into())) - })?, - )) - } + }; + + tracing::Instrument::instrument(async_block, span).await } #[cfg_attr(feature = "worker", worker::send)] @@ -510,6 +553,7 @@ impl completion::CompletionModel for CompletionModel { crate::streaming::StreamingCompletionResponse, CompletionError, > { + let preamble = request.preamble.clone(); let mut request = self.create_completion_request(request)?; request = merge( @@ -519,7 +563,26 @@ impl completion::CompletionModel for CompletionModel { let builder = self.client.reqwest_post("/chat/completions").json(&request); - 
send_compatible_streaming_request(builder).await + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat_streaming", + gen_ai.operation.name = "chat_streaming", + gen_ai.provider.name = "groq", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + + tracing::Instrument::instrument(send_compatible_streaming_request(builder), span).await } } @@ -661,16 +724,20 @@ pub async fn send_compatible_streaming_request( crate::streaming::StreamingCompletionResponse, CompletionError, > { + let span = tracing::Span::current(); let mut event_source = request_builder .eventsource() .expect("Cloning request must succeed"); let stream = Box::pin(stream! 
{ + let span = tracing::Span::current(); let mut final_usage = Usage { prompt_tokens: 0, total_tokens: 0 }; + let mut text_response = String::new(); + let mut calls: HashMap = HashMap::new(); while let Some(event_result) = event_source.next().await { @@ -745,6 +812,7 @@ pub async fn send_compatible_streaming_request( // Streamed content if let Some(content) = content { + text_response += content; yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone())); } } @@ -766,11 +834,21 @@ pub async fn send_compatible_streaming_request( } } + let mut tool_calls = Vec::new(); // Flush accumulated tool calls for (_, (id, name, arguments)) in calls { let Ok(arguments_json) = serde_json::from_str::(&arguments) else { continue; }; + + tool_calls.push(rig::providers::openai::completion::ToolCall { + id: id.clone(), + r#type: ToolType::Function, + function: Function { + name: name.clone(), + arguments: arguments_json.clone() + } + }); yield Ok(crate::streaming::RawStreamingChoice::ToolCall { id, name, @@ -779,11 +857,23 @@ pub async fn send_compatible_streaming_request( }); } + let response_message = crate::providers::openai::completion::Message::Assistant { + content: vec![AssistantContent::Text { text: text_response }], + refusal: None, + audio: None, + name: None, + tool_calls + }; + + span.record("gen_ai.output.messages", serde_json::to_string(&vec![response_message]).unwrap()); + span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens); + span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens); + // Final response yield Ok(crate::streaming::RawStreamingChoice::FinalResponse( StreamingCompletionResponse { usage: final_usage.clone() } )); - }); + }.instrument(span)); Ok(crate::streaming::StreamingCompletionResponse::stream( stream, diff --git a/rig-core/src/providers/huggingface/completion.rs b/rig-core/src/providers/huggingface/completion.rs index f94a5d818..6edb84134 100644 --- 
a/rig-core/src/providers/huggingface/completion.rs +++ b/rig-core/src/providers/huggingface/completion.rs @@ -1,9 +1,7 @@ -use serde::{Deserialize, Deserializer, Serialize}; -use serde_json::{Value, json}; -use std::{convert::Infallible, str::FromStr}; - use super::client::Client; +use crate::completion::GetTokenUsage; use crate::providers::openai::StreamingCompletionResponse; +use crate::telemetry::SpanCombinator; use crate::{ OneOrMany, completion::{self, CompletionError, CompletionRequest}, @@ -11,6 +9,10 @@ use crate::{ message::{self}, one_or_many::string_or_one_or_many, }; +use serde::{Deserialize, Deserializer, Serialize}; +use serde_json::{Value, json}; +use std::{convert::Infallible, str::FromStr}; +use tracing::info_span; #[derive(Debug, Deserialize)] #[serde(untagged)] @@ -201,6 +203,19 @@ impl TryFrom for UserContent { fn try_from(content: message::UserContent) -> Result { match content { message::UserContent::Text(text) => Ok(UserContent::Text { text: text.text }), + message::UserContent::Document(message::Document { + data: message::DocumentSourceKind::Raw(raw), + .. + }) => { + let text = String::from_utf8_lossy(raw.as_slice()).into(); + Ok(UserContent::Text { text }) + } + message::UserContent::Document(message::Document { + data: + message::DocumentSourceKind::Base64(text) + | message::DocumentSourceKind::String(text), + .. + }) => Ok(UserContent::Text { text }), message::UserContent::Image(message::Image { data, .. }) => match data { message::DocumentSourceKind::Url(url) => Ok(UserContent::ImageUrl { image_url: ImageUrl { url }, @@ -303,6 +318,17 @@ impl TryFrom for Vec { image_url: ImageUrl { url }, }) } + message::UserContent::Document(message::Document { + data: message::DocumentSourceKind::Raw(raw), .. 
+ }) => { + let text = String::from_utf8_lossy(raw.as_slice()).into(); + Ok(UserContent::Text { text }) + } + message::UserContent::Document(message::Document { + data: message::DocumentSourceKind::Base64(text) | message::DocumentSourceKind::String(text), .. + }) => { + Ok(UserContent::Text { text }) + } _ => Err(message::MessageError::ConversionError( "Huggingface inputs only support text and image URLs (both base64-encoded images and regular URLs)".into(), )), @@ -398,7 +424,7 @@ impl TryFrom for message::Message { } } -#[derive(Debug, Deserialize, Serialize)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct Choice { pub finish_reason: String, pub index: usize, @@ -414,7 +440,18 @@ pub struct Usage { pub total_tokens: i32, } -#[derive(Debug, Deserialize, Serialize)] +impl GetTokenUsage for Usage { + fn token_usage(&self) -> Option { + let mut usage = crate::completion::Usage::new(); + usage.input_tokens = self.prompt_tokens as u64; + usage.output_tokens = self.completion_tokens as u64; + usage.total_tokens = self.total_tokens as u64; + + Some(usage) + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct CompletionResponse { pub created: i32, pub id: String, @@ -425,6 +462,63 @@ pub struct CompletionResponse { pub usage: Usage, } +impl crate::telemetry::ProviderResponseExt for CompletionResponse { + type OutputMessage = Choice; + type Usage = Usage; + + fn get_response_id(&self) -> Option { + Some(self.id.clone()) + } + + fn get_response_model_name(&self) -> Option { + Some(self.model.clone()) + } + + fn get_output_messages(&self) -> Vec { + self.choices.clone() + } + + fn get_text_response(&self) -> Option { + let text_response = self + .choices + .iter() + .filter_map(|x| { + let Message::User { ref content } = x.message else { + return None; + }; + + let text = content + .iter() + .filter_map(|x| { + if let UserContent::Text { text } = x { + Some(text.clone()) + } else { + None + } + }) + .collect::>(); + + if text.is_empty() { + None + 
} else { + Some(text.join("\n")) + } + }) + .collect::>() + .join("\n"); + + if text_response.is_empty() { + None + } else { + Some(text_response) + } + } + + fn get_usage(&self) -> Option { + Some(self.usage.clone()) + } +} + fn default_string_on_null<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, @@ -537,6 +631,12 @@ impl CompletionModel { let model = self.client.sub_provider.model_identifier(&self.model); + let tool_choice = completion_request + .tool_choice + .clone() + .map(crate::providers::openai::completion::ToolChoice::try_from) + .transpose()?; + let request = if completion_request.tools.is_empty() { json!({ "model": model, @@ -549,7 +649,7 @@ impl CompletionModel { "messages": full_history, "temperature": completion_request.temperature, "tools": completion_request.tools.clone().into_iter().map(ToolDefinition::from).collect::>(), - "tool_choice": "auto", + "tool_choice": tool_choice, }) }; Ok(request) @@ -565,7 +665,26 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "huggingface", + gen_ai.request.model = self.model, + gen_ai.system_instructions = &completion_request.preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = tracing::field::Empty, + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; let request = self.create_request_body(&completion_request)?; + span.record_model_input(&request.get("messages")); let path = self.client.sub_provider.completion_endpoint(&self.model); @@ -593,10 +712,11 @@ impl completion::CompletionModel for CompletionModel { match 
serde_json::from_slice::>(&bytes)? { ApiResponse::Ok(response) => { - tracing::info!(target: "rig", - "Huggingface completion token usage: {:?}", - format!("{:?}", response.usage) - ); + let span = tracing::Span::current(); + span.record_token_usage(&response.usage); + span.record_model_output(&response.choices); + span.record_response_metadata(&response); + response.try_into() } ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.to_string())), diff --git a/rig-core/src/providers/huggingface/streaming.rs b/rig-core/src/providers/huggingface/streaming.rs index 35ce1764c..50038835d 100644 --- a/rig-core/src/providers/huggingface/streaming.rs +++ b/rig-core/src/providers/huggingface/streaming.rs @@ -4,6 +4,7 @@ use crate::json_utils::merge_inplace; use crate::providers::openai::{StreamingCompletionResponse, send_compatible_streaming_request}; use crate::streaming; use serde_json::json; +use tracing::{Instrument, info_span}; impl CompletionModel { pub(crate) async fn stream( @@ -26,10 +27,30 @@ impl CompletionModel { // HF Inference API uses the model in the path even though its specified in the request body let path = self.client.sub_provider.completion_endpoint(&self.model); - let request = serde_json::to_vec(&request)?; + let body = serde_json::to_vec(&request)?; - let builder = self.client.post_reqwest(&path).body(request); + let builder = self.client.post_reqwest(&path).body(body); - send_compatible_streaming_request(builder).await + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "huggingface", + gen_ai.request.model = self.model, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = self.model, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request["messages"]).unwrap(), + gen_ai.output.messages = 
tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + + send_compatible_streaming_request(builder) + .instrument(span) + .await } } diff --git a/rig-core/src/providers/hyperbolic.rs b/rig-core/src/providers/hyperbolic.rs index bbf6cc2fd..8b15776c4 100644 --- a/rig-core/src/providers/hyperbolic.rs +++ b/rig-core/src/providers/hyperbolic.rs @@ -393,6 +393,9 @@ impl CompletionModel { &self, completion_request: CompletionRequest, ) -> Result { + if completion_request.tool_choice.is_some() { + tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic"); + } // Build up the order of messages (context, chat_history, prompt) let mut partial_history = vec![]; if let Some(docs) = completion_request.normalized_documents() { @@ -441,39 +444,64 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { + let preamble = completion_request.preamble.clone(); let request = self.create_completion_request(completion_request)?; - let response = self - .client - .reqwest_post("/v1/chat/completions") - .json(&request) - .send() - .await - .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "hyperbolic", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; - if response.status().is_success() { - match response - .json::>() + let async_block = async move { + let response = self + 
.client + .reqwest_post("/v1/chat/completions") + .json(&request) + .send() .await - .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))? - { - ApiResponse::Ok(response) => { - tracing::info!(target: "rig", - "Hyperbolic completion token usage: {:?}", - response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string()) - ); - - response.try_into() + .map_err(|e| http_client::Error::Instance(e.into()))?; + + if response.status().is_success() { + match response + .json::>() + .await + .map_err(|e| http_client::Error::Instance(e.into()))? + { + ApiResponse::Ok(response) => { + tracing::info!(target: "rig", + "Hyperbolic completion token usage: {:?}", + response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string()) + ); + + response.try_into() + } + ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } - ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), + } else { + Err(CompletionError::ProviderError( + response + .text() + .await + .map_err(|e| http_client::Error::Instance(e.into()))?, + )) } - } else { - Err(CompletionError::ProviderError( - response.text().await.map_err(|e| { - CompletionError::HttpError(http_client::Error::Instance(e.into())) - })?, - )) - } + }; + + async_block.instrument(span).await } #[cfg_attr(feature = "worker", worker::send)] @@ -481,8 +509,28 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { + let preamble = completion_request.preamble.clone(); let mut request = self.create_completion_request(completion_request)?; + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat_streaming", + gen_ai.operation.name = "chat_streaming", + gen_ai.provider.name = "hyperbolic", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + 
gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + merge_inplace( &mut request, json!({"stream": true, "stream_options": {"include_usage": true}}), @@ -493,7 +541,9 @@ impl completion::CompletionModel for CompletionModel { .reqwest_post("/v1/chat/completions") .json(&request); - send_compatible_streaming_request(builder).await + send_compatible_streaming_request(builder) + .instrument(span) + .await } } @@ -644,6 +694,7 @@ mod image_generation { // ====================================== #[cfg(feature = "audio")] pub use audio_generation::*; +use tracing::{Instrument, info_span}; #[cfg(feature = "audio")] #[cfg_attr(docsrs, doc(cfg(feature = "image")))] diff --git a/rig-core/src/providers/mira.rs b/rig-core/src/providers/mira.rs index 7ec38dd00..08640d9ad 100644 --- a/rig-core/src/providers/mira.rs +++ b/rig-core/src/providers/mira.rs @@ -25,7 +25,7 @@ use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use std::string::FromUtf8Error; use thiserror::Error; -use tracing; +use tracing::{self, Instrument, info_span}; #[derive(Debug, Error)] pub enum MiraError { @@ -355,6 +355,10 @@ impl CompletionModel { &self, completion_request: CompletionRequest, ) -> Result { + if completion_request.tool_choice.is_some() { + tracing::warn!("WARNING: `tool_choice` not supported on Mira AI"); + } + let mut messages = Vec::new(); // Add preamble as user message if available @@ -371,7 +375,7 @@ impl CompletionModel { .into_iter() .filter_map(|doc| match doc { UserContent::Document(Document { - data: DocumentSourceKind::Base64(data), + data: DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data), .. 
}) => Some(data), UserContent::Text(text) => Some(text.text), @@ -442,36 +446,85 @@ impl completion::CompletionModel for CompletionModel { completion_request: CompletionRequest, ) -> Result, CompletionError> { if !completion_request.tools.is_empty() { - tracing::warn!(target: "rig", - "Tool calls are not supported by the Mira provider. {} tools will be ignored.", - completion_request.tools.len() + tracing::warn!(target: "rig::completions", + "Tool calls are not supported by the Mira provider. {len} tools will be ignored.", + len = completion_request.tools.len() ); } - let mira_request = self.create_completion_request(completion_request)?; + let preamble = completion_request.preamble.clone(); + + let request = self.create_completion_request(completion_request)?; + + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "mira", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; - let response = self - .client - .reqwest_post("/v1/chat/completions") - .json(&mira_request) - .send() - .await - .map_err(|e| CompletionError::ProviderError(e.to_string()))?; - - if !response.status().is_success() { - let status = response.status().as_u16(); - let error_text = response.text().await.unwrap_or_default(); - return Err(CompletionError::ProviderError(format!( - "API error: {status} - {error_text}" - ))); - } + let async_block = async move { + let response = self + .client + .reqwest_post("/v1/chat/completions") + .json(&request) + .send() + .await + 
.map_err(|e| CompletionError::ProviderError(e.to_string()))?; + + if !response.status().is_success() { + let status = response.status().as_u16(); + let error_text = response.text().await.unwrap_or_default(); + return Err(CompletionError::ProviderError(format!( + "API error: {status} - {error_text}" + ))); + } - let response: CompletionResponse = response - .json() - .await - .map_err(|e| CompletionError::ProviderError(e.to_string()))?; + let response: CompletionResponse = response + .json() + .await + .map_err(|e| CompletionError::ProviderError(e.to_string()))?; + + if let CompletionResponse::Structured { + id, + model, + choices, + usage, + .. + } = &response + { + let span = tracing::Span::current(); + span.record("gen_ai.response.model_name", model); + span.record("gen_ai.response.id", id); + span.record( + "gen_ai.output.messages", + serde_json::to_string(choices).unwrap(), + ); + if let Some(usage) = usage { + span.record("gen_ai.usage.input_tokens", usage.prompt_tokens); + span.record( + "gen_ai.usage.output_tokens", + usage.total_tokens - usage.prompt_tokens, + ); + } + } - response.try_into() + response.try_into() + }; + + async_block.instrument(span).await } #[cfg_attr(feature = "worker", worker::send)] @@ -479,8 +532,27 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { + let preamble = completion_request.preamble.clone(); let mut request = self.create_completion_request(completion_request)?; + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat_streaming", + gen_ai.operation.name = "chat_streaming", + gen_ai.provider.name = "mira", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + 
gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; request = merge(request, json!({"stream": true})); let builder = self @@ -488,7 +560,9 @@ impl completion::CompletionModel for CompletionModel { .reqwest_post("/v1/chat/completions") .json(&request); - send_compatible_streaming_request(builder).await + send_compatible_streaming_request(builder) + .instrument(span) + .await } } diff --git a/rig-core/src/providers/mistral/completion.rs b/rig-core/src/providers/mistral/completion.rs index 5a5e5820d..4b1c5c534 100644 --- a/rig-core/src/providers/mistral/completion.rs +++ b/rig-core/src/providers/mistral/completion.rs @@ -2,6 +2,7 @@ use async_stream::stream; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use std::{convert::Infallible, str::FromStr}; +use tracing::{Instrument, info_span}; use super::client::{Client, Usage}; use crate::completion::GetTokenUsage; @@ -12,6 +13,7 @@ use crate::{ completion::{self, CompletionError, CompletionRequest}, json_utils, message, providers::mistral::client::ApiResponse, + telemetry::SpanCombinator, }; pub const CODESTRAL: &str = "codestral-latest"; @@ -256,6 +258,33 @@ pub struct CompletionModel { pub model: String, } +#[derive(Debug, Default, Serialize, Deserialize)] +pub enum ToolChoice { + #[default] + Auto, + None, + Any, +} + +impl TryFrom for ToolChoice { + type Error = CompletionError; + + fn try_from(value: message::ToolChoice) -> Result { + let res = match value { + message::ToolChoice::Auto => Self::Auto, + message::ToolChoice::None => Self::None, + message::ToolChoice::Required => Self::Any, + message::ToolChoice::Specific { .. 
} => { + return Err(CompletionError::ProviderError( + "Mistral doesn't support requiring specific tools to be called".to_string(), + )); + } + }; + + Ok(res) + } +} + impl CompletionModel { pub fn new(client: Client, model: &str) -> Self { Self { @@ -290,6 +319,11 @@ impl CompletionModel { .collect::>(), ); + let tool_choice = completion_request + .tool_choice + .map(ToolChoice::try_from) + .transpose()?; + let request = if completion_request.tools.is_empty() { json!({ "model": self.model, @@ -301,7 +335,7 @@ impl CompletionModel { "model": self.model, "messages": full_history, "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::>(), - "tool_choice": "auto", + "tool_choice": tool_choice, }) }; @@ -337,6 +371,47 @@ pub struct CompletionResponse { pub usage: Option, } +impl crate::telemetry::ProviderResponseExt for CompletionResponse { + type OutputMessage = Choice; + type Usage = Usage; + + fn get_response_id(&self) -> Option { + Some(self.id.clone()) + } + + fn get_response_model_name(&self) -> Option { + Some(self.model.clone()) + } + + fn get_output_messages(&self) -> Vec { + self.choices.clone() + } + + fn get_text_response(&self) -> Option { + let res = self + .choices + .iter() + .filter_map(|choice| match choice.message { + Message::Assistant { ref content, .. 
} => { + if content.is_empty() { + None + } else { + Some(content.to_string()) + } + } + _ => None, + }) + .collect::>() + .join("\n"); + + if res.is_empty() { None } else { Some(res) } + } + + fn get_usage(&self) -> Option { + self.usage.clone() + } +} + impl GetTokenUsage for CompletionResponse { fn token_usage(&self) -> Option { let api_usage = self.usage.clone()?; @@ -424,8 +499,28 @@ where &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { - let body = self.create_completion_request(completion_request)?; - let body = serde_json::to_vec(&body)?; + let preamble = completion_request.preamble.clone(); + let request = self.create_completion_request(completion_request)?; + let body = serde_json::to_vec(&request)?; + + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "mistral", + gen_ai.request.model = self.model, + gen_ai.system_instructions = &preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; let request = self .client @@ -433,25 +528,28 @@ where .body(body) .map_err(|e| CompletionError::HttpError(e.into()))?; - let response = self.client.send(request).await?; - - if response.status().is_success() { - let text = http_client::text(response).await?; - - match serde_json::from_str::>(&text)? 
{ - ApiResponse::Ok(response) => { - tracing::debug!(target: "rig", - "Mistral completion token usage: {:?}", - response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string()) - ); - response.try_into() + async move { + let response = self.client.send(request).await?; + + if response.status().is_success() { + let text = http_client::text(response).await?; + match serde_json::from_str::>(&text)? { + ApiResponse::Ok(response) => { + let span = tracing::Span::current(); + span.record_token_usage(&response); + span.record_model_output(&response.choices); + span.record_response_metadata(&response); + response.try_into() + } + ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } - ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), + } else { + let text = http_client::text(response).await?; + Err(CompletionError::ProviderError(text)) } - } else { - let text = http_client::text(response).await?; - Err(CompletionError::ProviderError(text)) } + .instrument(span) + .await } #[cfg_attr(feature = "worker", worker::send)] diff --git a/rig-core/src/providers/moonshot.rs b/rig-core/src/providers/moonshot.rs index b8c5d6196..fc7cdea31 100644 --- a/rig-core/src/providers/moonshot.rs +++ b/rig-core/src/providers/moonshot.rs @@ -19,8 +19,9 @@ use crate::{ providers::openai, }; use crate::{http_client, impl_conversion_traits, message}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; +use tracing::{Instrument, info_span}; // ================================================================ // Main Moonshot Client @@ -280,6 +281,11 @@ impl CompletionModel { .collect::>(), ); + let tool_choice = completion_request + .tool_choice + .map(ToolChoice::try_from) + .transpose()?; + let request = if completion_request.tools.is_empty() { json!({ "model": self.model, @@ -292,7 +298,7 @@ impl CompletionModel { "messages": full_history, "temperature": completion_request.temperature, 
"tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::>(), - "tool_choice": "auto", + "tool_choice": tool_choice, }) }; @@ -315,40 +321,75 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { + let preamble = completion_request.preamble.clone(); let request = self.create_completion_request(completion_request)?; - let response = self - .client - .reqwest_post("/chat/completions") - .json(&request) - .send() - .await - .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "moonshot", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; - if response.status().is_success() { - let t = response - .text() + let async_block = async move { + let response = self + .client + .reqwest_post("/chat/completions") + .json(&request) + .send() .await - .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; - tracing::debug!(target: "rig", "MoonShot completion error: {}", t); - - match serde_json::from_str::>(&t)? 
{ - ApiResponse::Ok(response) => { - tracing::info!(target: "rig", - "MoonShot completion token usage: {:?}", - response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string()) - ); - response.try_into() + .map_err(|e| http_client::Error::Instance(e.into()))?; + + if response.status().is_success() { + let t = response + .text() + .await + .map_err(|e| http_client::Error::Instance(e.into()))?; + tracing::debug!(target: "rig::completions", "MoonShot completion response: {t}"); + + match serde_json::from_str::>(&t)? { + ApiResponse::Ok(response) => { + let span = tracing::Span::current(); + span.record("gen_ai.response.id", response.id.clone()); + span.record("gen_ai.response.model_name", response.model.clone()); + span.record( + "gen_ai.output.messages", + serde_json::to_string(&response.choices).unwrap(), + ); + if let Some(ref usage) = response.usage { + span.record("gen_ai.usage.input_tokens", usage.prompt_tokens); + span.record( + "gen_ai.usage.output_tokens", + usage.total_tokens - usage.prompt_tokens, + ); + } + response.try_into() + } + ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)), } - ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.error.message)), + } else { + Err(CompletionError::ProviderError( + response + .text() + .await + .map_err(|e| http_client::Error::Instance(e.into()))?, + )) } - } else { - Err(CompletionError::ProviderError( - response.text().await.map_err(|e| { - CompletionError::HttpError(http_client::Error::Instance(e.into())) - })?, - )) - } + }; + + async_block.instrument(span).await } #[cfg_attr(feature = "worker", worker::send)] @@ -356,8 +397,28 @@ impl completion::CompletionModel for CompletionModel { &self, request: CompletionRequest, ) -> Result, CompletionError> { + let preamble = request.preamble.clone(); let mut request = self.create_completion_request(request)?; + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + 
"chat_streaming", + gen_ai.operation.name = "chat_streaming", + gen_ai.provider.name = "moonshot", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + request = merge( request, json!({"stream": true, "stream_options": {"include_usage": true}}), @@ -365,6 +426,33 @@ impl completion::CompletionModel for CompletionModel { let builder = self.client.reqwest_post("/chat/completions").json(&request); - send_compatible_streaming_request(builder).await + send_compatible_streaming_request(builder) + .instrument(span) + .await + } +} + +#[derive(Default, Debug, Deserialize, Serialize)] +pub enum ToolChoice { + None, + #[default] + Auto, +} + +impl TryFrom for ToolChoice { + type Error = CompletionError; + + fn try_from(value: message::ToolChoice) -> Result { + let res = match value { + message::ToolChoice::None => Self::None, + message::ToolChoice::Auto => Self::Auto, + choice => { + return Err(CompletionError::ProviderError(format!( + "Unsupported tool choice type: {choice:?}" + ))); + } + }; + + Ok(res) } } diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs index 8921382dd..b1bf268b7 100644 --- a/rig-core/src/providers/ollama.rs +++ b/rig-core/src/providers/ollama.rs @@ -61,6 +61,9 @@ use reqwest; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use std::{convert::TryFrom, str::FromStr}; +use tracing::info_span; +use tracing_futures::Instrument; +use url::Url; // ---------- Main Client ---------- const OLLAMA_API_BASE_URL: &str = "http://localhost:11434"; @@ -469,6 +472,10 @@ impl CompletionModel { &self, 
completion_request: CompletionRequest, ) -> Result { + if completion_request.tool_choice.is_some() { + tracing::warn!("WARNING: `tool_choice` not supported for Ollama"); + } + // Build up the order of messages (context, chat_history) let mut partial_history = vec![]; if let Some(docs) = completion_request.normalized_documents() { @@ -559,36 +566,76 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { - let request_payload = self.create_completion_request(completion_request)?; + let preamble = completion_request.preamble.clone(); + let request = self.create_completion_request(completion_request)?; + + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "ollama", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; - let response = self - .client - .reqwest_post("api/chat") - .json(&request_payload) - .send() - .await - .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; + let async_block = async move { + let response = self + .client + .reqwest_post("api/chat") + .json(&request) + .send() + .await + .map_err(|e| http_client::Error::Instance(e.into()))?; + + if !response.status().is_success() { + return Err(CompletionError::ProviderError( + response + .text() + .await + .map_err(|e| http_client::Error::Instance(e.into()))?, + )); + } - if !response.status().is_success() { - return Err(CompletionError::ProviderError( - 
response.text().await.map_err(|e| { - CompletionError::HttpError(http_client::Error::Instance(e.into())) - })?, - )); - } + let bytes = response + .bytes() + .await + .map_err(|e| http_client::Error::Instance(e.into()))?; - let bytes = response - .bytes() - .await - .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; + tracing::debug!(target: "rig", "Received response from Ollama: {}", String::from_utf8_lossy(&bytes)); - tracing::debug!(target: "rig", "Received response from Ollama: {}", String::from_utf8_lossy(&bytes)); + let response: CompletionResponse = serde_json::from_slice(&bytes)?; + let span = tracing::Span::current(); + span.record("gen_ai.response.model_name", &response.model); + span.record( + "gen_ai.output.messages", + serde_json::to_string(&vec![&response.message]).unwrap(), + ); + span.record( + "gen_ai.usage.input_tokens", + response.prompt_eval_count.unwrap_or_default(), + ); + span.record( + "gen_ai.usage.output_tokens", + response.eval_count.unwrap_or_default(), + ); - let chat_resp: CompletionResponse = serde_json::from_slice(&bytes)?; + let response: completion::CompletionResponse = + response.try_into()?; - let conv: completion::CompletionResponse = chat_resp.try_into()?; + Ok(response) + }; - Ok(conv) + tracing::Instrument::instrument(async_block, span).await } #[cfg_attr(feature = "worker", worker::send)] @@ -597,32 +644,54 @@ impl completion::CompletionModel for CompletionModel { request: CompletionRequest, ) -> Result, CompletionError> { - let mut request_payload = self.create_completion_request(request)?; - merge_inplace(&mut request_payload, json!({"stream": true})); + let preamble = request.preamble.clone(); + let mut request = self.create_completion_request(request)?; + merge_inplace(&mut request, json!({"stream": true})); + + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat_streaming", + gen_ai.operation.name = "chat_streaming", + 
gen_ai.provider.name = "ollama", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = self.model, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; let response = self .client .reqwest_post("api/chat") - .json(&request_payload) + .json(&request) .send() .await - .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; + .map_err(|e| http_client::Error::Instance(e.into()))?; if !response.status().is_success() { return Err(CompletionError::ProviderError( - response.text().await.map_err(|e| { - CompletionError::HttpError(http_client::Error::Instance(e.into())) - })?, + response + .text() + .await + .map_err(|e| http_client::Error::Instance(e.into()))?, )); } let stream = Box::pin(try_stream! 
{ + let span = tracing::Span::current(); let mut byte_stream = response.bytes_stream(); + let mut tool_calls_final = Vec::new(); + let mut text_response = String::new(); while let Some(chunk) = byte_stream.next().await { - let bytes = chunk.map_err(|e| { - CompletionError::HttpError(http_client::Error::Instance(e.into())) - })?; + let bytes = chunk.map_err(|e| http_client::Error::Instance(e.into()))?; for line in bytes.split(|&b| b == b'\n') { if line.is_empty() { @@ -634,6 +703,16 @@ impl completion::CompletionModel for CompletionModel { let response: CompletionResponse = serde_json::from_slice(line)?; if response.done { + span.record("gen_ai.usage.input_tokens", response.prompt_eval_count); + span.record("gen_ai.usage.output_tokens", response.eval_count); + let message = Message::Assistant { + content: text_response.clone(), + thinking: None, + images: None, + name: None, + tool_calls: tool_calls_final.clone() + }; + span.record("gen_ai.output.messages", serde_json::to_string(&vec![message]).unwrap()); yield RawStreamingChoice::FinalResponse( StreamingCompletionResponse { total_duration: response.total_duration, @@ -650,9 +729,11 @@ impl completion::CompletionModel for CompletionModel { if let Message::Assistant { content, tool_calls, .. 
} = response.message { if !content.is_empty() { + text_response += &content; yield RawStreamingChoice::Message(content); } for tool_call in tool_calls { + tool_calls_final.push(tool_call.clone()); yield RawStreamingChoice::ToolCall { id: String::new(), name: tool_call.function.name, @@ -663,7 +744,7 @@ impl completion::CompletionModel for CompletionModel { } } } - }); + }.instrument(span)); Ok(streaming::StreamingCompletionResponse::stream(stream)) } @@ -695,7 +776,6 @@ impl From for ToolDefinition { #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] pub struct ToolCall { - // pub id: String, #[serde(default, rename = "type")] pub r#type: ToolType, pub function: Function, @@ -807,7 +887,9 @@ impl TryFrom for Vec { }) => images.push(data), crate::message::UserContent::Document( crate::message::Document { - data: DocumentSourceKind::Base64(data), + data: + DocumentSourceKind::Base64(data) + | DocumentSourceKind::String(data), .. }, ) => texts.push(data), diff --git a/rig-core/src/providers/openai/completion/mod.rs b/rig-core/src/providers/openai/completion/mod.rs index fa91da95a..f9d028957 100644 --- a/rig-core/src/providers/openai/completion/mod.rs +++ b/rig-core/src/providers/openai/completion/mod.rs @@ -3,15 +3,18 @@ // ================================================================ use super::{ApiErrorResponse, ApiResponse, Client, streaming::StreamingCompletionResponse}; -use crate::completion::{CompletionError, CompletionRequest}; +use crate::completion::{ + CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage, +}; use crate::http_client::HttpClientExt; use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType}; use crate::one_or_many::string_or_one_or_many; +use crate::telemetry::{ProviderResponseExt, SpanCombinator}; use crate::{OneOrMany, completion, http_client, json_utils, message}; use serde::{Deserialize, Serialize}; -use serde_json::{Value, json}; use std::convert::Infallible; use std::fmt; +use 
tracing::{Instrument, info_span}; use std::str::FromStr; @@ -271,6 +274,33 @@ impl From for ToolDefinition { } } +#[derive(Default, Clone, Debug, Deserialize, Serialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum ToolChoice { + #[default] + Auto, + None, + Required, +} + +impl TryFrom for ToolChoice { + type Error = CompletionError; + fn try_from(value: crate::message::ToolChoice) -> Result { + let res = match value { + message::ToolChoice::Specific { .. } => { + return Err(CompletionError::ProviderError( + "Provider doesn't support only using specific tools".to_string(), + )); + } + message::ToolChoice::Auto => Self::Auto, + message::ToolChoice::None => Self::None, + message::ToolChoice::Required => Self::Required, + }; + + Ok(res) + } +} + #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] pub struct Function { pub name: String, @@ -351,11 +381,19 @@ impl TryFrom for Vec { image_url: ImageUrl { url, detail }, }) } + DocumentSourceKind::Raw(_) => { + Err(message::MessageError::ConversionError( + "Raw files not supported, encode as base64 first".into(), + )) + } DocumentSourceKind::Unknown => { Err(message::MessageError::ConversionError( "Document has no body".into(), )) } + doc => Err(message::MessageError::ConversionError(format!( + "Unsupported document type: {doc:?}" + ))), }, message::UserContent::Document(message::Document { data, .. }) => { if let DocumentSourceKind::Base64(text) = data { @@ -669,7 +707,40 @@ impl TryFrom for completion::CompletionResponse Option { + Some(self.id.to_owned()) + } + + fn get_response_model_name(&self) -> Option { + Some(self.model.to_owned()) + } + + fn get_output_messages(&self) -> Vec { + self.choices.clone() + } + + fn get_text_response(&self) -> Option { + let Message::User { ref content, .. 
} = self.choices.last()?.message.clone() else { + return None; + }; + + let UserContent::Text { text } = content.first() else { + return None; + }; + + Some(text) + } + + fn get_usage(&self) -> Option { + self.usage.clone() + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Choice { pub index: usize, pub message: Message, @@ -711,6 +782,17 @@ impl fmt::Display for Usage { } } +impl GetTokenUsage for Usage { + fn token_usage(&self) -> Option { + let mut usage = crate::completion::Usage::new(); + usage.input_tokens = self.prompt_tokens as u64; + usage.output_tokens = (self.total_tokens - self.prompt_tokens) as u64; + usage.total_tokens = self.total_tokens as u64; + + Some(usage) + } +} + #[derive(Clone)] pub struct CompletionModel { pub(crate) client: Client, @@ -728,22 +810,41 @@ where model: model.to_string(), } } +} - pub(crate) fn create_completion_request( - &self, - completion_request: CompletionRequest, - ) -> Result { - // Build up the order of messages (context, chat_history) +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct CompletionRequest { + model: String, + messages: Vec, + tools: Vec, + tool_choice: Option, + temperature: Option, + #[serde(flatten)] + additional_params: Option, +} + +impl TryFrom<(String, CoreCompletionRequest)> for CompletionRequest { + type Error = CompletionError; + + fn try_from((model, req): (String, CoreCompletionRequest)) -> Result { let mut partial_history = vec![]; - if let Some(docs) = completion_request.normalized_documents() { + if let Some(docs) = req.normalized_documents() { partial_history.push(docs); } - partial_history.extend(completion_request.chat_history); + let CoreCompletionRequest { + preamble, + chat_history, + tools, + temperature, + additional_params, + tool_choice, + .. 
+ } = req; - // Initialize full history with preamble (or empty if non-existent) - let mut full_history: Vec = completion_request - .preamble - .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]); + partial_history.extend(chat_history); + + let mut full_history: Vec = + preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]); // Convert and extend the rest of the history full_history.extend( @@ -756,41 +857,59 @@ where .collect::>(), ); - let request = if completion_request.tools.is_empty() { - serde_json::json!({ - "model": self.model, - "messages": full_history, + let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?; - }) - } else { - json!({ - "model": self.model, - "messages": full_history, - "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::>(), - "tool_choice": "auto", - }) + let res = Self { + model, + messages: full_history, + tools: tools + .into_iter() + .map(ToolDefinition::from) + .collect::>(), + tool_choice, + temperature, + additional_params, }; - // only include temperature if it exists - // because some models don't support temperature - let request = if let Some(temperature) = completion_request.temperature { - json_utils::merge( - request, - json!({ - "temperature": temperature, - }), - ) - } else { - request + Ok(res) + } +} + +impl crate::telemetry::ProviderRequestExt for CompletionRequest { + type InputMessage = Message; + + fn get_input_messages(&self) -> Vec { + self.messages.clone() + } + + fn get_system_prompt(&self) -> Option { + let first_message = self.messages.first()?; + + let Message::System { ref content, .. } = first_message.clone() else { + return None; }; - let request = if let Some(params) = completion_request.additional_params { - json_utils::merge(request, params) - } else { - request + let SystemContent { text, .. 
} = content.first(); + + Some(text) + } + + fn get_prompt(&self) -> Option { + let last_message = self.messages.last()?; + + let Message::User { ref content, .. } = last_message.clone() else { + return None; }; - Ok(request) + let UserContent::Text { text } = content.first() else { + return None; + }; + + Some(text) + } + + fn get_model_name(&self) -> String { + self.model.clone() } } @@ -807,14 +926,29 @@ impl completion::CompletionModel for CompletionModel { #[cfg_attr(feature = "worker", worker::send)] async fn completion( &self, - completion_request: CompletionRequest, + completion_request: CoreCompletionRequest, ) -> Result, CompletionError> { - let request = self.create_completion_request(completion_request)?; + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "openai", + gen_ai.request.model = self.model, + gen_ai.system_instructions = &completion_request.preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = tracing::field::Empty, + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; - tracing::debug!( - "OpenAI request: {request}", - request = serde_json::to_string_pretty(&request).unwrap() - ); + let request = CompletionRequest::try_from((self.model.to_owned(), completion_request))?; + span.record_model_input(&request.messages); let body = serde_json::to_vec(&request)?; @@ -824,33 +958,36 @@ impl completion::CompletionModel for CompletionModel { .body(body) .map_err(|e| CompletionError::HttpError(e.into()))?; - let response = self.client.send(req).await?; - - if response.status().is_success() { - let text = http_client::text(response).await?; - - tracing::debug!(target: "rig", "OpenAI completion error: {}", text); - 
- match serde_json::from_str::>(&text)? { - ApiResponse::Ok(response) => { - tracing::info!(target: "rig", - "OpenAI completion token usage: {:?}", - response.usage.clone().map(|usage| format!("{}", usage.total_tokens)).unwrap_or("N/A".to_string()) - ); - response.try_into() + async move { + let response = self.client.send(req).await?; + + if response.status().is_success() { + let text = http_client::text(response).await?; + + match serde_json::from_str::>(&text)? { + ApiResponse::Ok(response) => { + let span = tracing::Span::current(); + span.record_model_output(&response.choices); + span.record_response_metadata(&response); + span.record_token_usage(&response.usage); + tracing::debug!("OpenAI response: {response:?}"); + response.try_into() + } + ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } - ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), + } else { + let text = http_client::text(response).await?; + Err(CompletionError::ProviderError(text)) } - } else { - let text = http_client::text(response).await?; - Err(CompletionError::ProviderError(text)) } + .instrument(span) + .await } #[cfg_attr(feature = "worker", worker::send)] async fn stream( &self, - request: CompletionRequest, + request: CoreCompletionRequest, ) -> Result< crate::streaming::StreamingCompletionResponse, CompletionError, diff --git a/rig-core/src/providers/openai/completion/streaming.rs b/rig-core/src/providers/openai/completion/streaming.rs index 94313002e..a5f3473ef 100644 --- a/rig-core/src/providers/openai/completion/streaming.rs +++ b/rig-core/src/providers/openai/completion/streaming.rs @@ -12,7 +12,8 @@ use reqwest_eventsource::RequestBuilderExt; use serde::{Deserialize, Serialize}; use serde_json::json; use std::collections::HashMap; -use tracing::debug; +use tracing::{debug, info_span}; +use tracing_futures::Instrument; // ================================================================ // OpenAI Completion Streaming API @@ -72,31 
+73,61 @@ impl CompletionModel { completion_request: CompletionRequest, ) -> Result, CompletionError> { - let mut request = self.create_completion_request(completion_request)?; - request = merge( - request, + let request = super::CompletionRequest::try_from((self.model.clone(), completion_request))?; + let request_messages = serde_json::to_string(&request.messages) + .expect("Converting to JSON from a Rust struct shouldn't fail"); + let mut request_as_json = serde_json::to_value(request).expect("this should never fail"); + + request_as_json = merge( + request_as_json, json!({"stream": true, "stream_options": {"include_usage": true}}), ); - let builder = self.client.post_reqwest("/chat/completions").json(&request); - send_compatible_streaming_request(builder).await + let builder = self + .client + .post_reqwest("/chat/completions") + .json(&request_as_json); + + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "openai", + gen_ai.request.model = self.model, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = self.model, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = request_messages, + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + + tracing::Instrument::instrument(send_compatible_streaming_request(builder), span).await } } pub async fn send_compatible_streaming_request( request_builder: RequestBuilder, ) -> Result, CompletionError> { + let span = tracing::Span::current(); // Build the request with proper headers for SSE let mut event_source = request_builder .eventsource() .expect("Cloning request must always succeed"); let stream = Box::pin(stream! 
{ + let span = tracing::Span::current(); let mut final_usage = Usage::new(); // Track in-progress tool calls let mut tool_calls: HashMap = HashMap::new(); + let mut text_content = String::new(); + while let Some(event_result) = event_source.next().await { match event_result { Ok(Event::Open) => { @@ -172,6 +203,7 @@ pub async fn send_compatible_streaming_request( // Message content if let Some(content) = &choice.delta.content { + text_content += content; yield Ok(streaming::RawStreamingChoice::Message(content.clone())) } } @@ -195,12 +227,22 @@ pub async fn send_compatible_streaming_request( // Ensure event source is closed when stream ends event_source.close(); + let mut vec_toolcalls = vec![]; + // Flush any tool calls that weren’t fully yielded for (_, (id, name, arguments)) in tool_calls { - let Ok(arguments) = serde_json::from_str(&arguments) else { + let Ok(arguments) = serde_json::from_str::(&arguments) else { continue; }; + vec_toolcalls.push(super::ToolCall { + r#type: super::ToolType::Function, + id: id.clone(), + function: super::Function { + name: name.clone(), arguments: arguments.clone() + }, + }); + yield Ok(RawStreamingChoice::ToolCall { id, name, @@ -209,10 +251,22 @@ pub async fn send_compatible_streaming_request( }); } + let message_output = super::Message::Assistant { + content: vec![super::AssistantContent::Text { text: text_content }], + refusal: None, + audio: None, + name: None, + tool_calls: vec_toolcalls + }; + + span.record("gen_ai.usage.input_tokens", final_usage.prompt_tokens); + span.record("gen_ai.usage.output_tokens", final_usage.total_tokens - final_usage.prompt_tokens); + span.record("gen_ai.output.messages", serde_json::to_string(&vec![message_output]).expect("Converting from a Rust struct should always convert to JSON without failing")); + yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { usage: final_usage.clone() })); - }); + }.instrument(span)); Ok(streaming::StreamingCompletionResponse::stream(stream)) 
} diff --git a/rig-core/src/providers/openai/responses_api/mod.rs b/rig-core/src/providers/openai/responses_api/mod.rs index fc60a986f..56fc537a8 100644 --- a/rig-core/src/providers/openai/responses_api/mod.rs +++ b/rig-core/src/providers/openai/responses_api/mod.rs @@ -7,17 +7,23 @@ //! let openai_client = rig::providers::openai::Client::from_env(); //! let model = openai_client.completion_model("gpt-4o").completions_api(); //! ``` +use super::completion::ToolChoice; use super::{Client, responses_api::streaming::StreamingCompletionResponse}; -use super::{ImageUrl, InputAudio, SystemContent}; +use super::{InputAudio, SystemContent}; use crate::completion::CompletionError; +use crate::http_client; use crate::http_client::HttpClientExt; -use crate::message::{AudioMediaType, Document, DocumentSourceKind, MessageError, MimeType, Text}; +use crate::json_utils; +use crate::message::{ + AudioMediaType, Document, DocumentMediaType, DocumentSourceKind, ImageDetail, MessageError, + MimeType, Text, +}; use crate::one_or_many::string_or_one_or_many; -use crate::{http_client, json_utils}; use crate::{OneOrMany, completion, message}; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; +use tracing::{Instrument, info_span}; use std::convert::Infallible; use std::ops::Add; @@ -45,8 +51,10 @@ pub struct CompletionRequest { /// The temperature. Set higher (up to a max of 1.0) for more creative responses. #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, - // TODO: Fix this before opening a PR! - // tool_choice: Option, + /// Whether the LLM should be forced to use a tool before returning a response. + /// If none provided, the default option is "auto". + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, /// The tools you want to use. Currently this is limited to functions, but will be expanded on in future. 
#[serde(skip_serializing_if = "Vec::is_empty")] pub tools: Vec, @@ -232,6 +240,41 @@ impl TryFrom for Vec { }); } } + crate::message::UserContent::Document(Document { + data, + media_type: Some(DocumentMediaType::PDF), + .. + }) => { + let (file_data, file_url) = match data { + DocumentSourceKind::Base64(data) => { + (Some(format!("data:application/pdf;base64,{data}")), None) + } + DocumentSourceKind::Url(url) => (None, Some(url)), + DocumentSourceKind::Raw(_) => { + return Err(CompletionError::RequestError( + "Raw file data not supported, encode as base64 first" + .into(), + )); + } + doc => { + return Err(CompletionError::RequestError( + format!("Unsupported document type: {doc}").into(), + )); + } + }; + + items.push(InputItem { + role: Some(Role::User), + input: InputContent::Message(Message::User { + content: OneOrMany::one(UserContent::InputFile { + file_data, + file_url, + filename: Some("document.pdf".to_string()), + }), + name: None, + }), + }) + } // todo: should we ensure this takes into account file size? crate::message::UserContent::Document(Document { data: DocumentSourceKind::Base64(text), @@ -243,6 +286,16 @@ impl TryFrom for Vec { name: None, }), }), + crate::message::UserContent::Document(Document { + data: DocumentSourceKind::String(text), + .. 
+ }) => items.push(InputItem { + role: Some(Role::User), + input: InputContent::Message(Message::User { + content: OneOrMany::one(UserContent::InputText { text }), + name: None, + }), + }), crate::message::UserContent::Image(crate::message::Image { data, media_type, @@ -259,16 +312,24 @@ impl TryFrom for Vec { format!("data:{media_type};base64,{data}") } DocumentSourceKind::Url(url) => url, - DocumentSourceKind::Unknown => return Err(CompletionError::RequestError("Attempted to create an OpenAI Responses AI image input from unknown variant".into())) + DocumentSourceKind::Raw(_) => { + return Err(CompletionError::RequestError( + "Raw file data not supported, encode as base64 first" + .into(), + )); + } + doc => { + return Err(CompletionError::RequestError( + format!("Unsupported document type: {doc}").into(), + )); + } }; items.push(InputItem { role: Some(Role::User), input: InputContent::Message(Message::User { content: OneOrMany::one(UserContent::InputImage { - image_url: ImageUrl { - url, - detail: detail.unwrap_or_default(), - }, + image_url: url, + detail: detail.unwrap_or_default(), }), name: None, }), @@ -611,12 +672,15 @@ impl TryFrom<(String, crate::completion::CompletionRequest)> for CompletionReque AdditionalParameters::default() }; + let tool_choice = req.tool_choice.map(ToolChoice::try_from).transpose()?; + Ok(Self { input, model, instructions: req.preamble, max_output_tokens: req.max_tokens, stream, + tool_choice, tools: req .tools .into_iter() @@ -980,10 +1044,33 @@ impl completion::CompletionModel for ResponsesCompletionModel { &self, completion_request: crate::completion::CompletionRequest, ) -> Result, CompletionError> { - let body = self.create_completion_request(completion_request)?; - tracing::debug!("OpenAI input: {}", serde_json::to_string_pretty(&body)?); + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = 
tracing::field::Empty, + gen_ai.request.model = tracing::field::Empty, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = tracing::field::Empty, + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; - let body = serde_json::to_vec(&body)?; + span.record("gen_ai.provider.name", "openai"); + span.record("gen_ai.request.model", &self.model); + let request = self.create_completion_request(completion_request)?; + span.record( + "gen_ai.input.messages", + serde_json::to_string(&request.input) + .expect("openai request to successfully turn into a JSON value"), + ); + let body = serde_json::to_vec(&request)?; let req = self .client @@ -991,18 +1078,33 @@ impl completion::CompletionModel for ResponsesCompletionModel { .body(body) .map_err(|e| CompletionError::HttpError(e.into()))?; - let response = self.client.send(req).await?; - - if response.status().is_success() { - let text = http_client::text(response).await?; - tracing::debug!(target: "rig", "OpenAI response: {}", text); - - let response = serde_json::from_str::(&text)?; - response.try_into() - } else { - let text = http_client::text(response).await?; - Err(CompletionError::ProviderError(text)) + async move { + let response = self.client.send(req).await?; + + if response.status().is_success() { + let t = http_client::text(response).await?; + let response = serde_json::from_str::(&t)?; + let span = tracing::Span::current(); + span.record( + "gen_ai.output.messages", + serde_json::to_string(&response.output).unwrap(), + ); + span.record("gen_ai.response.id", &response.id); + span.record("gen_ai.response.model", &response.model); + if let Some(ref usage) = response.usage { + span.record("gen_ai.usage.output_tokens", usage.output_tokens); + span.record("gen_ai.usage.input_tokens", usage.input_tokens); + } 
+ // We need to call the event here to get the span to actually send anything + tracing::info!("API successfully called"); + response.try_into() + } else { + let text = http_client::text(response).await?; + Err(CompletionError::ProviderError(text)) + } } + .instrument(span) + .await } #[cfg_attr(feature = "worker", worker::send)] @@ -1146,7 +1248,17 @@ pub enum UserContent { text: String, }, InputImage { - image_url: ImageUrl, + image_url: String, + #[serde(default)] + detail: ImageDetail, + }, + InputFile { + #[serde(skip_serializing_if = "Option::is_none")] + file_url: Option, + #[serde(skip_serializing_if = "Option::is_none")] + file_data: Option, + #[serde(skip_serializing_if = "Option::is_none")] + filename: Option, }, Audio { input_audio: InputAudio, @@ -1216,19 +1328,57 @@ impl TryFrom for Vec { format!("data:{media_type};base64,{data}") } DocumentSourceKind::Url(url) => url, - DocumentSourceKind::Unknown => return Err(MessageError::ConversionError("Attempted to convert unknown image type to OpenAI image input".to_string())) + DocumentSourceKind::Raw(_) => { + return Err(MessageError::ConversionError( + "Raw files not supported, encode as base64 first" + .into(), + )); + } + doc => { + return Err(MessageError::ConversionError(format!( + "Unsupported document type: {doc}" + ))); + } }; Ok(UserContent::InputImage { - image_url: ImageUrl { - url, - detail: detail.unwrap_or_default(), - }, + image_url: url, + detail: detail.unwrap_or_default(), }) } - message::UserContent::Document(message::Document { data: DocumentSourceKind::Base64(text), .. }) => { - Ok(UserContent::InputText { text }) + message::UserContent::Document(message::Document { + media_type: Some(DocumentMediaType::PDF), + data, + .. 
+ }) => { + let (file_data, file_url) = match data { + DocumentSourceKind::Base64(data) => { + (Some(format!("data:application/pdf;base64,{data}")), None) + } + DocumentSourceKind::Url(url) => (None, Some(url)), + DocumentSourceKind::Raw(_) => { + return Err(MessageError::ConversionError( + "Raw files not supported, encode as base64 first" + .into(), + )); + } + doc => { + return Err(MessageError::ConversionError(format!( + "Unsupported document type: {doc}" + ))); + } + }; + + Ok(UserContent::InputFile { + file_url, + file_data, + filename: Some("document.pdf".into()), + }) } + message::UserContent::Document(message::Document { + data: DocumentSourceKind::Base64(text), + .. + }) => Ok(UserContent::InputText { text }), message::UserContent::Audio(message::Audio { data: DocumentSourceKind::Base64(data), media_type, @@ -1242,8 +1392,10 @@ impl TryFrom for Vec { }, }, }), - message::UserContent::Audio(_) => Err(MessageError::ConversionError("Audio must be base64 encoded data".into())), - _ => unreachable!() + message::UserContent::Audio(_) => Err(MessageError::ConversionError( + "Audio must be base64 encoded data".into(), + )), + _ => unreachable!(), }) .collect::, _>>()?; diff --git a/rig-core/src/providers/openai/responses_api/streaming.rs b/rig-core/src/providers/openai/responses_api/streaming.rs index d914f43af..d952ac674 100644 --- a/rig-core/src/providers/openai/responses_api/streaming.rs +++ b/rig-core/src/providers/openai/responses_api/streaming.rs @@ -8,11 +8,11 @@ use crate::streaming; use crate::streaming::RawStreamingChoice; use async_stream::stream; use futures::StreamExt; -use reqwest::RequestBuilder; use reqwest_eventsource::Event; use reqwest_eventsource::RequestBuilderExt; use serde::{Deserialize, Serialize}; -use tracing::debug; +use tracing::{debug, info_span}; +use tracing_futures::Instrument as _; use super::{CompletionResponse, Output}; @@ -200,116 +200,139 @@ impl ResponsesCompletionModel { let mut request = 
self.create_completion_request(completion_request)?; request.stream = Some(true); - tracing::debug!("Input: {}", serde_json::to_string_pretty(&request)?); - - let builder = self.client.post_reqwest("/responses").json(&request); - send_compatible_streaming_request(builder).await - } -} - -/// Send a compatible streaming request. -/// The following are assumed to already be set: -/// - The URL to send a POST request to -/// - The JSON body -pub async fn send_compatible_streaming_request( - request_builder: RequestBuilder, -) -> Result, CompletionError> { - // Build the request with proper headers for SSE - let mut event_source = request_builder - .eventsource() - .expect("Cloning request must always succeed"); - - let stream = Box::pin(stream! { - let mut final_usage = ResponsesUsage::new(); - - let mut tool_calls: Vec> = Vec::new(); - - while let Some(event_result) = event_source.next().await { - match event_result { - Ok(Event::Open) => { - tracing::trace!("SSE connection opened"); - continue; - } - Ok(Event::Message(message)) => { - // Skip heartbeat messages or empty data - if message.data.trim().is_empty() { + let request_builder = self.client.post_reqwest("/responses").json(&request); + + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat_streaming", + gen_ai.operation.name = "chat_streaming", + gen_ai.provider.name = tracing::field::Empty, + gen_ai.request.model = tracing::field::Empty, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = tracing::field::Empty, + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + span.record("gen_ai.provider.name", "openai"); + span.record("gen_ai.request.model", &self.model); + span.record( + "gen_ai.input.messages", + 
serde_json::to_string(&request.input).expect("This should always work"), + ); + // Build the request with proper headers for SSE + let mut event_source = request_builder + .eventsource() + .expect("Cloning request must always succeed"); + + let stream = Box::pin(stream! { + let mut final_usage = ResponsesUsage::new(); + + let mut tool_calls: Vec> = Vec::new(); + let mut combined_text = String::new(); + let span = tracing::Span::current(); + + while let Some(event_result) = event_source.next().await { + match event_result { + Ok(Event::Open) => { + tracing::trace!("SSE connection opened"); + tracing::info!("OpenAI stream started"); continue; } + Ok(Event::Message(message)) => { + // Skip heartbeat messages or empty data + if message.data.trim().is_empty() { + continue; + } - let data = serde_json::from_str::(&message.data); - - let Ok(data) = data else { - let err = data.unwrap_err(); - debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err); - continue; - }; - - if let StreamingCompletionChunk::Delta(chunk) = &data { - match &chunk.data { - ItemChunkKind::OutputItemDone(message) => { - match message { - StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } => { - tracing::debug!("Function call received: {func:?}"); - tool_calls.push(streaming::RawStreamingChoice::ToolCall { id: func.id.clone(), call_id: Some(func.call_id.clone()), name: func.name.clone(), arguments: func.arguments.clone() }); - } + let data = serde_json::from_str::(&message.data); - StreamingItemDoneOutput { item: Output::Reasoning { summary, id }, .. 
} => { - let reasoning = summary - .iter() - .map(|x| { - let ReasoningSummary::SummaryText { text } = x; - text.to_owned() - }) - .collect::>() - .join("\n"); - yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning, id: Some(id.to_string()) }) + let Ok(data) = data else { + let err = data.unwrap_err(); + debug!("Couldn't serialize data as StreamingCompletionResponse: {:?}", err); + continue; + }; + + if let StreamingCompletionChunk::Delta(chunk) = &data { + match &chunk.data { + ItemChunkKind::OutputItemDone(message) => { + match message { + StreamingItemDoneOutput { item: Output::FunctionCall(func), .. } => { + tool_calls.push(streaming::RawStreamingChoice::ToolCall { id: func.id.clone(), call_id: Some(func.call_id.clone()), name: func.name.clone(), arguments: func.arguments.clone() }); + } + + StreamingItemDoneOutput { item: Output::Reasoning { summary, id }, .. } => { + let reasoning = summary + .iter() + .map(|x| { + let ReasoningSummary::SummaryText { text } = x; + text.to_owned() + }) + .collect::>() + .join("\n"); + yield Ok(streaming::RawStreamingChoice::Reasoning { reasoning, id: Some(id.to_string()) }) + } + _ => continue } - _ => continue } - } - ItemChunkKind::OutputTextDelta(delta) => { - yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone())) - } - ItemChunkKind::RefusalDelta(delta) => { - yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone())) - } + ItemChunkKind::OutputTextDelta(delta) => { + combined_text.push_str(&delta.delta); + yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone())) + } + ItemChunkKind::RefusalDelta(delta) => { + combined_text.push_str(&delta.delta); + yield Ok(streaming::RawStreamingChoice::Message(delta.delta.clone())) + } - _ => { continue } + _ => { continue } + } } - } - if let StreamingCompletionChunk::Response(chunk) = data { - if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. 
} = *chunk { - if let Some(usage) = response.usage { - final_usage = usage; + if let StreamingCompletionChunk::Response(chunk) = data { + if let ResponseChunk { kind: ResponseChunkKind::ResponseCompleted, response, .. } = *chunk { + span.record("gen_ai.output.messages", serde_json::to_string(&response.output).unwrap()); + span.record("gen_ai.response.id", response.id); + span.record("gen_ai.response.model", response.model); + if let Some(usage) = response.usage { + final_usage = usage; + } + } else { + continue; } - } else { - continue; } } - } - Err(reqwest_eventsource::Error::StreamEnded) => { - break; - } - Err(error) => { - tracing::error!(?error, "SSE error"); - yield Err(CompletionError::ResponseError(error.to_string())); - break; + Err(reqwest_eventsource::Error::StreamEnded) => { + break; + } + Err(error) => { + tracing::error!(?error, "SSE error"); + yield Err(CompletionError::ResponseError(error.to_string())); + break; + } } } - } - // Ensure event source is closed when stream ends - event_source.close(); + // Ensure event source is closed when stream ends + event_source.close(); - for tool_call in &tool_calls { - yield Ok(tool_call.to_owned()) - } + for tool_call in &tool_calls { + yield Ok(tool_call.to_owned()) + } + + span.record("gen_ai.usage.input_tokens", final_usage.input_tokens); + span.record("gen_ai.usage.output_tokens", final_usage.output_tokens); + tracing::info!("OpenAI stream finished"); - yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { - usage: final_usage.clone() - })); - }); + yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { + usage: final_usage.clone() + })); + }.instrument(span)); - Ok(streaming::StreamingCompletionResponse::stream(stream)) + Ok(streaming::StreamingCompletionResponse::stream(stream)) + } } diff --git a/rig-core/src/providers/openrouter/client.rs b/rig-core/src/providers/openrouter/client.rs index 9233939a1..60e005fe1 100644 --- 
a/rig-core/src/providers/openrouter/client.rs +++ b/rig-core/src/providers/openrouter/client.rs @@ -1,5 +1,6 @@ use crate::{ - client::{CompletionClient, ProviderClient, VerifyClient, VerifyError}, + client::{ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError}, + completion::GetTokenUsage, http_client::{self, HttpClientExt}, impl_conversion_traits, }; @@ -217,3 +218,15 @@ impl std::fmt::Display for Usage { ) } } + +impl GetTokenUsage for Usage { + fn token_usage(&self) -> Option { + let mut usage = crate::completion::Usage::new(); + + usage.input_tokens = self.prompt_tokens as u64; + usage.output_tokens = self.completion_tokens as u64; + usage.total_tokens = self.total_tokens as u64; + + Some(usage) + } +} diff --git a/rig-core/src/providers/openrouter/completion.rs b/rig-core/src/providers/openrouter/completion.rs index ecfcf83de..05fac4502 100644 --- a/rig-core/src/providers/openrouter/completion.rs +++ b/rig-core/src/providers/openrouter/completion.rs @@ -1,4 +1,5 @@ use serde::{Deserialize, Serialize}; +use tracing::{Instrument, info_span}; use super::client::{ApiErrorResponse, ApiResponse, Client, Usage}; @@ -13,6 +14,7 @@ use serde_json::{Value, json}; use crate::providers::openai::AssistantContent; use crate::providers::openrouter::streaming::FinalCompletionResponse; use crate::streaming::StreamingCompletionResponse; +use crate::telemetry::SpanCombinator; // ================================================================ // OpenRouter Completion API @@ -121,6 +123,43 @@ pub struct Choice { pub finish_reason: Option, } +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged, rename_all = "snake_case")] +pub enum ToolChoice { + None, + Auto, + Required, + Function(Vec), +} + +impl TryFrom for ToolChoice { + type Error = CompletionError; + + fn try_from(value: crate::message::ToolChoice) -> Result { + let res = match value { + crate::message::ToolChoice::None => Self::None, + crate::message::ToolChoice::Auto => Self::Auto, + 
crate::message::ToolChoice::Required => Self::Required, + crate::message::ToolChoice::Specific { function_names } => { + let vec: Vec = function_names + .into_iter() + .map(|name| ToolChoiceFunctionKind::Function { name }) + .collect(); + + Self::Function(vec) + } + }; + + Ok(res) + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", content = "function")] +pub enum ToolChoiceFunctionKind { + Function { name: String }, +} + #[derive(Clone)] pub struct CompletionModel { pub(crate) client: Client, @@ -165,11 +204,17 @@ impl CompletionModel { // Combine all messages into a single history full_history.extend(chat_history); + let tool_choice = completion_request + .tool_choice + .map(ToolChoice::try_from) + .transpose()?; + let request = json!({ "model": self.model, "messages": full_history, "temperature": completion_request.temperature, - "tools": completion_request.tools.into_iter().map(crate::providers::openai::completion::ToolDefinition::from).collect::>() + "tools": completion_request.tools.into_iter().map(crate::providers::openai::completion::ToolDefinition::from).collect::>(), + "tool_choice": tool_choice, }); let request = if let Some(params) = completion_request.additional_params { @@ -191,41 +236,67 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { + let preamble = completion_request.preamble.clone(); let request = self.create_completion_request(completion_request)?; + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completion", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "openrouter", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = 
serde_json::to_string(request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; - let response = self - .client - .reqwest_client() - .post("/chat/completions") - .json(&request) - .send() - .await - .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; - - if response.status().is_success() { - match response - .json::>() + async move { + let response = self + .client + .reqwest_client() + .post("/chat/completions") + .json(&request) + .send() .await - .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))? - { - ApiResponse::Ok(response) => { - tracing::info!(target: "rig", - "OpenRouter completion token usage: {:?}", - response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string()) - ); - tracing::debug!(target: "rig", - "OpenRouter response: {response:?}"); - response.try_into() + .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; + + if response.status().is_success() { + match response + .json::>() + .await + .map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })? 
{ + ApiResponse::Ok(response) => { + let span = tracing::Span::current(); + span.record_token_usage(&response.usage); + span.record_model_output(&response.choices); + span.record("gen_ai.response.id", &response.id); + span.record("gen_ai.response.model_name", &response.model); + + tracing::debug!(target: "rig::completion", + "OpenRouter response: {response:?}"); + response.try_into() + } + ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } - ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), + } else { + Err(CompletionError::ProviderError( + response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } - } else { - Err(CompletionError::ProviderError( - response.text().await.map_err(|e| { - CompletionError::HttpError(http_client::Error::Instance(e.into())) - })?, - )) } + .instrument(span) + .await } #[cfg_attr(feature = "worker", worker::send)] diff --git a/rig-core/src/providers/openrouter/streaming.rs b/rig-core/src/providers/openrouter/streaming.rs index 09e5524a0..8c8383027 100644 --- a/rig-core/src/providers/openrouter/streaming.rs +++ b/rig-core/src/providers/openrouter/streaming.rs @@ -1,5 +1,6 @@ use reqwest_eventsource::{Event, RequestBuilderExt}; use std::collections::HashMap; +use tracing::info_span; use crate::{ completion::GetTokenUsage, @@ -118,13 +119,33 @@ impl super::CompletionModel { completion_request: CompletionRequest, ) -> Result, CompletionError> { + let preamble = completion_request.preamble.clone(); let request = self.create_completion_request(completion_request)?; let request = json_utils::merge(request, json!({"stream": true})); let builder = self.client.reqwest_post("/chat/completions").json(&request); - send_streaming_request(builder).await + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat_streaming", + gen_ai.operation.name = "chat_streaming", + gen_ai.provider.name = 
"openrouter", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + + tracing::Instrument::instrument(send_streaming_request(builder), span).await } } diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs index 2e6d25bba..b6a3ef015 100644 --- a/rig-core/src/providers/perplexity.rs +++ b/rig-core/src/providers/perplexity.rs @@ -23,6 +23,7 @@ use crate::providers::openai::send_compatible_streaming_request; use crate::streaming::StreamingCompletionResponse; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; +use tracing::{Instrument, info_span}; // ================================================================ // Main Cohere Client @@ -286,6 +287,10 @@ impl CompletionModel { &self, completion_request: CompletionRequest, ) -> Result { + if completion_request.tool_choice.is_some() { + tracing::warn!("WARNING: `tool_choice` not supported on Perplexity"); + } + // Build up the order of messages (context, chat_history, prompt) let mut partial_history = vec![]; if let Some(docs) = completion_request.normalized_documents() { @@ -398,38 +403,71 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: completion::CompletionRequest, ) -> Result, CompletionError> { + let preamble = completion_request.preamble.clone(); let request = self.create_completion_request(completion_request)?; - let response = self - .client - .reqwest_post("/chat/completions") - .json(&request) - .send() - .await - .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; + let span = if 
tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "perplexity", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; - if response.status().is_success() { - match response - .json::>() + let async_block = async move { + let response = self + .client + .reqwest_post("/chat/completions") + .json(&request) + .send() .await - .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))? - { - ApiResponse::Ok(completion) => { - tracing::info!(target: "rig", - "Perplexity completion token usage: {}", - completion.usage - ); - Ok(completion.try_into()?) + .map_err(|e| http_client::Error::Instance(e.into()))?; + + if response.status().is_success() { + match response + .json::>() + .await + .map_err(|e| http_client::Error::Instance(e.into()))? + { + ApiResponse::Ok(completion) => { + let span = tracing::Span::current(); + span.record("gen_ai.usage.input_tokens", completion.usage.prompt_tokens); + span.record( + "gen_ai.usage.output_tokens", + completion.usage.completion_tokens, + ); + span.record( + "gen_ai.output.messages", + serde_json::to_string(&completion.choices).unwrap(), + ); + span.record("gen_ai.response.id", completion.id.to_string()); + span.record("gen_ai.response.model_name", completion.model.to_string()); + Ok(completion.try_into()?) 
+ } + ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)), } - ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)), + } else { + Err(CompletionError::ProviderError( + response + .text() + .await + .map_err(|e| http_client::Error::Instance(e.into()))?, + )) } - } else { - Err(CompletionError::ProviderError( - response.text().await.map_err(|e| { - CompletionError::HttpError(http_client::Error::Instance(e.into())) - })?, - )) - } + }; + + async_block.instrument(span).await } #[cfg_attr(feature = "worker", worker::send)] @@ -437,13 +475,34 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: completion::CompletionRequest, ) -> Result, CompletionError> { + let preamble = completion_request.preamble.clone(); let mut request = self.create_completion_request(completion_request)?; request = merge(request, json!({"stream": true})); let builder = self.client.reqwest_post("/chat/completions").json(&request); - send_compatible_streaming_request(builder).await + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat_streaming", + gen_ai.operation.name = "chat_streaming", + gen_ai.provider.name = "perplexity", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + send_compatible_streaming_request(builder) + .instrument(span) + .await } } diff --git a/rig-core/src/providers/together/completion.rs b/rig-core/src/providers/together/completion.rs index 7c858fb15..d3ba9a433 100644 --- a/rig-core/src/providers/together/completion.rs +++ 
b/rig-core/src/providers/together/completion.rs @@ -12,7 +12,9 @@ use crate::{ use super::client::{Client, together_ai_api_types::ApiResponse}; use crate::completion::CompletionRequest; use crate::streaming::StreamingCompletionResponse; +use serde::{Deserialize, Serialize}; use serde_json::json; +use tracing::{Instrument, info_span}; // ================================================================ // Together Completion Models @@ -164,6 +166,11 @@ impl CompletionModel { full_history.extend(chat_history); + let tool_choice = completion_request + .tool_choice + .map(ToolChoice::try_from) + .transpose()?; + let mut request = if completion_request.tools.is_empty() { json!({ "model": self.model, @@ -176,7 +183,7 @@ impl CompletionModel { "messages": full_history, "temperature": completion_request.temperature, "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::>(), - "tool_choice": "auto", + "tool_choice": tool_choice, }) }; request = if let Some(params) = completion_request.additional_params { @@ -197,41 +204,77 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: completion::CompletionRequest, ) -> Result, CompletionError> { + let preamble = completion_request.preamble.clone(); let request = self.create_completion_request(completion_request)?; + let messages_as_json_string = + serde_json::to_string(request.get("messages").unwrap()).unwrap(); + + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "together", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = &messages_as_json_string, + gen_ai.output.messages = tracing::field::Empty, 
+ ) + } else { + tracing::Span::current() + }; + + tracing::debug!(target: "rig::completion", "TogetherAI completion request: {messages_as_json_string}"); - let response = self - .client - .reqwest_post("/v1/chat/completions") - .json(&request) - .send() - .await - .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; - - if response.status().is_success() { - let text = response - .text() + async move { + let response = self + .client + .reqwest_post("/v1/chat/completions") + .json(&request) + .send() .await .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; - tracing::debug!(target: "rig", "Together completion error: {}", text); + if response.status().is_success() { + let t = response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?; + tracing::debug!(target: "rig::completion", "TogetherAI completion response: {t}"); - match serde_json::from_str::>(&text)? { - ApiResponse::Ok(response) => { - tracing::info!(target: "rig", - "Together completion token usage: {:?}", - response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string()) - ); - response.try_into() + match serde_json::from_str::>(&t)? 
{ + ApiResponse::Ok(response) => { + let span = tracing::Span::current(); + span.record( + "gen_ai.output.messages", + serde_json::to_string(&response.choices).unwrap(), + ); + span.record("gen_ai.response.id", &response.id); + span.record("gen_ai.response.model", &response.model); + if let Some(ref usage) = response.usage { + span.record("gen_ai.usage.input_tokens", usage.prompt_tokens); + span.record( + "gen_ai.usage.output_tokens", + usage.total_tokens - usage.prompt_tokens, + ); + } + response.try_into() + } + ApiResponse::Error(err) => Err(CompletionError::ProviderError(err.error)), } - ApiResponse::Error(err) => Err(CompletionError::ProviderError(err.error)), + } else { + Err(CompletionError::ProviderError( + response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } - } else { - Err(CompletionError::ProviderError( - response.text().await.map_err(|e| { - CompletionError::HttpError(http_client::Error::Instance(e.into())) - })?, - )) } + .instrument(span) + .await } #[cfg_attr(feature = "worker", worker::send)] @@ -242,3 +285,43 @@ impl completion::CompletionModel for CompletionModel { CompletionModel::stream(self, request).await } } + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged, rename_all = "snake_case")] +pub enum ToolChoice { + None, + Auto, + Function(Vec), +} + +impl TryFrom for ToolChoice { + type Error = CompletionError; + + fn try_from(value: crate::message::ToolChoice) -> Result { + let res = match value { + crate::message::ToolChoice::None => Self::None, + crate::message::ToolChoice::Auto => Self::Auto, + crate::message::ToolChoice::Specific { function_names } => { + let vec: Vec = function_names + .into_iter() + .map(|name| ToolChoiceFunctionKind::Function { name }) + .collect(); + + Self::Function(vec) + } + choice => { + return Err(CompletionError::ProviderError(format!( + "Unsupported tool choice type: {choice:?}" + ))); + } + }; + + Ok(res) + } +} + +#[derive(Debug, 
Serialize, Deserialize)] +#[serde(tag = "type", content = "function")] +pub enum ToolChoiceFunctionKind { + Function { name: String }, +} diff --git a/rig-core/src/providers/together/streaming.rs b/rig-core/src/providers/together/streaming.rs index 08a54a1f2..9d34f2d9a 100644 --- a/rig-core/src/providers/together/streaming.rs +++ b/rig-core/src/providers/together/streaming.rs @@ -9,12 +9,15 @@ use crate::{ json_utils::merge, }; +use tracing::{Instrument, info_span}; + impl CompletionModel { pub(crate) async fn stream( &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { + let preamble = completion_request.preamble.clone(); let mut request = self.create_completion_request(completion_request)?; request = merge(request, json!({"stream_tokens": true})); @@ -24,6 +27,27 @@ impl CompletionModel { .reqwest_post("/v1/chat/completions") .json(&request); - send_compatible_streaming_request(builder).await + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat_streaming", + gen_ai.operation.name = "chat_streaming", + gen_ai.provider.name = "together", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + + send_compatible_streaming_request(builder) + .instrument(span) + .await } } diff --git a/rig-core/src/providers/xai/completion.rs b/rig-core/src/providers/xai/completion.rs index fe683be5d..06ca3ec54 100644 --- a/rig-core/src/providers/xai/completion.rs +++ b/rig-core/src/providers/xai/completion.rs @@ -14,6 +14,7 @@ use crate::completion::CompletionRequest; use crate::providers::openai; 
use crate::streaming::StreamingCompletionResponse; use serde_json::{Value, json}; +use tracing::{Instrument, info_span}; use xai_api_types::{CompletionResponse, ToolDefinition}; /// xAI completion models as of 2025-06-04 @@ -71,6 +72,11 @@ impl CompletionModel { // Chat history and prompt appear in the order they were provided full_history.extend(chat_history); + let tool_choice = completion_request + .tool_choice + .map(crate::providers::openrouter::ToolChoice::try_from) + .transpose()?; + let mut request = if completion_request.tools.is_empty() { json!({ "model": self.model, @@ -83,7 +89,7 @@ impl CompletionModel { "messages": full_history, "temperature": completion_request.temperature, "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::>(), - "tool_choice": "auto", + "tool_choice": tool_choice, }) }; @@ -113,33 +119,63 @@ impl completion::CompletionModel for CompletionModel { &self, completion_request: completion::CompletionRequest, ) -> Result, CompletionError> { + let preamble = completion_request.preamble.clone(); let request = self.create_completion_request(completion_request)?; + let request_messages_json_str = + serde_json::to_string(&request.get("messages").unwrap()).unwrap(); + + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "xai", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = &request_messages_json_str, + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; - let response = self - .client - .reqwest_post("/v1/chat/completions") - .json(&request) - .send() - .await - .map_err(|e| 
CompletionError::HttpError(http_client::Error::Instance(e.into())))?; - - if response.status().is_success() { - let body = response - .json::>() + tracing::debug!("xAI completion request: {request_messages_json_str}"); + + async move { + let response = self + .client + .reqwest_post("/v1/chat/completions") + .json(&request) + .send() .await .map_err(|e| CompletionError::HttpError(http_client::Error::Instance(e.into())))?; - match body { - ApiResponse::Ok(completion) => completion.try_into(), - ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message())), + if response.status().is_success() { + match response + .json::>() + .await + .map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })? { + ApiResponse::Ok(completion) => completion.try_into(), + ApiResponse::Error(error) => { + Err(CompletionError::ProviderError(error.message())) + } + } + } else { + Err(CompletionError::ProviderError( + response.text().await.map_err(|e| { + CompletionError::HttpError(http_client::Error::Instance(e.into())) + })?, + )) } - } else { - Err(CompletionError::ProviderError( - response.text().await.map_err(|e| { - CompletionError::HttpError(http_client::Error::Instance(e.into())) - })?, - )) } + .instrument(span) + .await } #[cfg_attr(feature = "worker", worker::send)] diff --git a/rig-core/src/providers/xai/streaming.rs b/rig-core/src/providers/xai/streaming.rs index 8ba32cf17..d11b8728c 100644 --- a/rig-core/src/providers/xai/streaming.rs +++ b/rig-core/src/providers/xai/streaming.rs @@ -5,6 +5,7 @@ use crate::providers::openai::send_compatible_streaming_request; use crate::providers::xai::completion::CompletionModel; use crate::streaming::StreamingCompletionResponse; use serde_json::json; +use tracing::{Instrument, info_span}; impl CompletionModel { pub(crate) async fn stream( @@ -12,6 +13,7 @@ impl CompletionModel { completion_request: CompletionRequest, ) -> Result, CompletionError> { + let preamble = 
completion_request.preamble.clone(); let mut request = self.create_completion_request(completion_request)?; request = merge(request, json!({"stream": true})); @@ -21,6 +23,27 @@ impl CompletionModel { .reqwest_post("/v1/chat/completions") .json(&request); - send_compatible_streaming_request(builder).await + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat_streaming", + gen_ai.operation.name = "chat_streaming", + gen_ai.provider.name = "xai", + gen_ai.request.model = self.model, + gen_ai.system_instructions = preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + + send_compatible_streaming_request(builder) + .instrument(span) + .await } } diff --git a/rig-core/src/telemetry/mod.rs b/rig-core/src/telemetry/mod.rs new file mode 100644 index 000000000..ca8e8f7ca --- /dev/null +++ b/rig-core/src/telemetry/mod.rs @@ -0,0 +1,96 @@ +//! This module primarily concerns being able to orchestrate telemetry across a given pipeline or workflow. +//! This includes tracing, being able to send traces to an OpenTelemetry collector, setting up your +//! agents with the correct tracing style so you can emit the right traces for platforms like Langfuse, +//! and more. 
+ +use crate::completion::GetTokenUsage; +use serde::Serialize; + +pub trait ProviderRequestExt { + type InputMessage: Serialize; + + fn get_input_messages(&self) -> Vec; + fn get_system_prompt(&self) -> Option; + fn get_model_name(&self) -> String; + fn get_prompt(&self) -> Option; +} + +pub trait ProviderResponseExt { + type OutputMessage: Serialize; + type Usage: Serialize; + + fn get_response_id(&self) -> Option; + + fn get_response_model_name(&self) -> Option; + + fn get_output_messages(&self) -> Vec; + + fn get_text_response(&self) -> Option; + + fn get_usage(&self) -> Option; +} + +/// A trait designed specifically to be used with Spans for the purpose of recording telemetry. +/// Nearly all methods record `gen_ai.*` attributes onto the span. +pub trait SpanCombinator { + fn record_token_usage(&self, usage: &U) + where + U: GetTokenUsage; + + fn record_response_metadata(&self, response: &R) + where + R: ProviderResponseExt; + + fn record_model_input(&self, messages: &T) + where + T: Serialize; + + fn record_model_output(&self, messages: &T) + where + T: Serialize; +} + +impl SpanCombinator for tracing::Span { + fn record_token_usage(&self, usage: &U) + where + U: GetTokenUsage, + { + if let Some(usage) = usage.token_usage() { + self.record("gen_ai.usage.input_tokens", usage.input_tokens); + self.record("gen_ai.usage.output_tokens", usage.output_tokens); + } + } + + fn record_response_metadata(&self, response: &R) + where + R: ProviderResponseExt, + { + if let Some(id) = response.get_response_id() { + self.record("gen_ai.response.id", id); + } + + if let Some(model_name) = response.get_response_model_name() { + self.record("gen_ai.response.model", model_name); + } + } + + fn record_model_input(&self, input: &T) + where + T: Serialize, + { + let input_as_json_string = + serde_json::to_string(input).expect("Serializing a Rust type to JSON should not break"); + + self.record("gen_ai.input.messages", input_as_json_string); + } + + fn record_model_output(&self, input: &T) + where + T: Serialize, + { + 
let input_as_json_string = + serde_json::to_string(input).expect("Serializing a Rust type to JSON should not break"); + + self.record("gen_ai.output.messages", input_as_json_string); + } +} diff --git a/rig-core/src/tools/think.rs b/rig-core/src/tools/think.rs index 772044de2..d6019daaf 100644 --- a/rig-core/src/tools/think.rs +++ b/rig-core/src/tools/think.rs @@ -23,7 +23,7 @@ pub struct ThinkError(String); /// It doesn't actually perform any actions or retrieve any information - it just /// provides a space for the model to reason through complex problems. /// -/// This tool is original dervived from the +/// This tool was originally derived from the /// [Think tool](https://anthropic.com/engineering/claude-think-tool) blog post from Anthropic. #[derive(Deserialize, Serialize)] pub struct ThinkTool; diff --git a/rig-eternalai/CHANGELOG.md b/rig-eternalai/CHANGELOG.md index 17ea46bc1..c9eef375d 100644 --- a/rig-eternalai/CHANGELOG.md +++ b/rig-eternalai/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.3.8](https://github.com/0xPlaygrounds/rig/compare/rig-eternalai-v0.3.7...rig-eternalai-v0.3.8) - 2025-09-29 + +### Other + +- updated the following local packages: rig-core + ## [0.3.7](https://github.com/0xPlaygrounds/rig/compare/rig-eternalai-v0.3.6...rig-eternalai-v0.3.7) - 2025-09-15 ### Other diff --git a/rig-eternalai/Cargo.toml b/rig-eternalai/Cargo.toml index 4548fdbd4..131a3b654 100644 --- a/rig-eternalai/Cargo.toml +++ b/rig-eternalai/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rig-eternalai" -version = "0.3.7" +version = "0.3.8" edition = { workspace = true } license = "MIT" readme = "README.md" @@ -8,7 +8,7 @@ description = "EternalAI model provider Rig integration." 
repository = "https://github.com/0xPlaygrounds/rig" [dependencies] -rig-core = { path = "../rig-core", version = "0.20.0" } +rig-core = { path = "../rig-core", version = "0.21.0" } ethers = { workspace = true } reqwest = { workspace = true, features = ["json"] } serde = { workspace = true, features = ["derive"] } diff --git a/rig-fastembed/CHANGELOG.md b/rig-fastembed/CHANGELOG.md index 85a9c23db..59aa57ad7 100644 --- a/rig-fastembed/CHANGELOG.md +++ b/rig-fastembed/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.12](https://github.com/0xPlaygrounds/rig/compare/rig-fastembed-v0.2.11...rig-fastembed-v0.2.12) - 2025-09-29 + +### Other + +- updated the following local packages: rig-core + ## [0.2.11](https://github.com/0xPlaygrounds/rig/compare/rig-fastembed-v0.2.10...rig-fastembed-v0.2.11) - 2025-09-15 ### Other diff --git a/rig-fastembed/Cargo.toml b/rig-fastembed/Cargo.toml index 27ce2228d..03a31bdda 100644 --- a/rig-fastembed/Cargo.toml +++ b/rig-fastembed/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rig-fastembed" -version = "0.2.11" +version = "0.2.12" edition = { workspace = true } license = "MIT" readme = "README.md" @@ -8,7 +8,7 @@ description = "Rig vector store index integration for Fastembed. https://github. 
repository = "https://github.com/0xPlaygrounds/rig" [dependencies] -rig-core = { path = "../rig-core", version = "0.20.0" } +rig-core = { path = "../rig-core", version = "0.21.0" } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tracing = { workspace = true } diff --git a/rig-helixdb/.gitignore b/rig-helixdb/.gitignore new file mode 100644 index 000000000..ed68366d4 --- /dev/null +++ b/rig-helixdb/.gitignore @@ -0,0 +1 @@ +.helix diff --git a/rig-helixdb/Cargo.toml b/rig-helixdb/Cargo.toml new file mode 100644 index 000000000..9b0fbfa3f --- /dev/null +++ b/rig-helixdb/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "rig-helixdb" +version = "0.1.0" +edition.workspace = true +license = "MIT" +readme = "README.md" +description = "Rig vector store index integration for HelixDB." +repository = "https://github.com/0xPlaygrounds/rig" + +[dependencies] +helix-rs = "0.1.9" +serde = { workspace = true, features = ["derive"] } +serde_json.workspace = true +rig-core = { path = "../rig-core", version = "0.21.0" } + +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } + +[[example]] +name = "vector_search_helixdb" +required-features = ["rig-core/derive"] diff --git a/rig-helixdb/README.md b/rig-helixdb/README.md new file mode 100644 index 000000000..b453bbb04 --- /dev/null +++ b/rig-helixdb/README.md @@ -0,0 +1,35 @@ +# Rig HelixDB integration +This crate integrates HelixDB into Rig, allowing you to easily use RAG with this database. + +## Installation +To install this crate, run the following command in a Rust project directory which will add `rig-helixdb` as a dependency (requires `rig-core` added for intended usage): +```bash +cargo add rig-helixdb +``` + +There's a few different ways you can run HelixDB: +- Through HelixDB's cloud offering +- Running it locally through their `helix start` command (requires the Helix CLI to be installed). 
+ - For local dev, you will likely want to use `helix push dev` for continuous iteration. + +## How to run the example +Before running the example, you will need to ensure that you are running an instance of HelixDB which you can do with `helix dockerdev run`. + +Once done, you will then need to deploy your queries/schema. **The queries/schema in the `examples/helixdb-cfg` folder are a required minimum to use this integration.** `rig-helixdb` also provides a way to get a manual handle on the client yourself so that you can invoke your own queries should you need to. + +Assuming `rig-helixdb` is your current working directory, to deploy a minimum viable configuration for HelixDB (with `rig-helixdb`) you will need to `cd` into the `helixdb-cfg` folder and then run the following: ```bash +helix push dev +``` + +This will then deploy the queries/schema into your instance. + +To run the example, add your OpenAI API key as an environment variable: ```bash +export OPENAI_API_KEY=my_key +``` + +Finally, use the following command below to run the example: ```bash +cargo run --example vector_search_helixdb --features rig-core/derive +``` diff --git a/rig-helixdb/examples/helixdb-cfg/db/queries.hx b/rig-helixdb/examples/helixdb-cfg/db/queries.hx new file mode 100644 index 000000000..213082d03 --- /dev/null +++ b/rig-helixdb/examples/helixdb-cfg/db/queries.hx @@ -0,0 +1,7 @@ +QUERY InsertVector (vector: [F64], doc: String, json_payload: String) => + AddV(vector, { doc: doc, json_payload: json_payload }) + RETURN doc + +QUERY VectorSearch(vector: [F64], limit: U64, threshold: F64) => + vec_docs <- SearchV(vector, limit) + RETURN vec_docs diff --git a/rig-helixdb/examples/helixdb-cfg/db/schema.hx b/rig-helixdb/examples/helixdb-cfg/db/schema.hx new file mode 100644 index 000000000..4119bffcf --- /dev/null +++ b/rig-helixdb/examples/helixdb-cfg/db/schema.hx @@ -0,0 +1,4 @@ +V::Document { + doc: String, + json_payload: String +} diff --git 
a/rig-helixdb/examples/helixdb-cfg/helix.toml b/rig-helixdb/examples/helixdb-cfg/helix.toml new file mode 100644 index 000000000..3066c7583 --- /dev/null +++ b/rig-helixdb/examples/helixdb-cfg/helix.toml @@ -0,0 +1,15 @@ +[project] +name = "helixdb-cfg" +queries = "./db/" + +[local.dev] +port = 6969 +build_mode = "debug" + +[local.dev.vector_config] +m = 16 +ef_construction = 128 +ef_search = 768 +db_max_size_gb = 10 + +[cloud] diff --git a/rig-helixdb/examples/vector_search_helixdb.rs b/rig-helixdb/examples/vector_search_helixdb.rs new file mode 100644 index 000000000..016735e8a --- /dev/null +++ b/rig-helixdb/examples/vector_search_helixdb.rs @@ -0,0 +1,79 @@ +use helix_rs::{HelixDB, HelixDBClient}; +use rig::{ + Embed, + client::{EmbeddingsClient, ProviderClient}, + embeddings::EmbeddingsBuilder, + vector_store::{InsertDocuments, VectorSearchRequest, VectorStoreIndex}, +}; +use rig_helixdb::HelixDBVectorStore; +use serde::{Deserialize, Serialize}; + +// A vector search needs to be performed on the `definitions` field, so we derive the `Embed` trait for `WordDefinition` +// and tag that field with `#[embed]`. 
+// We are not going to store the definitions on our database so we skip the `Serialize` trait +#[derive(Embed, Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Default)] +struct WordDefinition { + word: String, + #[serde(skip)] // we don't want to serialize this field, we use only to create embeddings + #[embed] + definition: String, +} + +impl std::fmt::Display for WordDefinition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.word) + } +} + +#[tokio::main] +async fn main() { + let openai_model = + rig::providers::openai::Client::from_env().embedding_model("text-embedding-ada-002"); + + let helixdb_client = HelixDB::new(None, Some(6969), None); // Uses default port 6969 + let vector_store = HelixDBVectorStore::new(helixdb_client, openai_model.clone()); + + let words = vec![ + WordDefinition { + word: "flurbo".to_string(), + definition: "1. *flurbo* (name): A fictional digital currency that originated in the animated series Rick and Morty.".to_string() + }, + WordDefinition { + word: "glarb-glarb".to_string(), + definition: "1. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + }, + WordDefinition { + word: "linglingdong".to_string(), + definition: "1. 
*linglingdong* (noun): A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + }]; + + let documents = EmbeddingsBuilder::new(openai_model) + .documents(words) + .unwrap() + .build() + .await + .expect("Failed to create embeddings"); + + vector_store.insert_documents(documents).await.unwrap(); + + let query = "What is a flurbo?"; + let vector_req = VectorSearchRequest::builder() + .query(query) + .samples(5) + .build() + .unwrap(); + + let docs = vector_store + .top_n::(vector_req) + .await + .unwrap(); + + for doc in docs { + println!( + "Vector found with id: {id} and score: {score} and word def: {doc}", + id = doc.1, + score = doc.0, + doc = doc.2 + ) + } +} diff --git a/rig-helixdb/src/lib.rs b/rig-helixdb/src/lib.rs new file mode 100644 index 000000000..614aa05b0 --- /dev/null +++ b/rig-helixdb/src/lib.rs @@ -0,0 +1,181 @@ +use helix_rs::HelixDBClient; +use rig::{ + embeddings::EmbeddingModel, + vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex}, +}; +use serde::{Deserialize, Serialize}; + +/// A client for easily carrying out Rig-related vector store operations. +/// +/// If you are unsure what type to use for the client, `helix_rs::HelixDB` is the typical default. +/// +/// Usage: +/// ```rust +/// let openai_model = +/// rig::providers::openai::Client::from_env().embedding_model("text-embedding-ada-002"); +/// +/// let helixdb_client = HelixDB::new(None, Some(6969), None); +/// let vector_store = HelixDBVectorStore::new(helixdb_client, openai_model.clone()); +/// ``` +pub struct HelixDBVectorStore { + client: C, + model: E, +} + +/// The result of a query. Only used internally as this is a representative type required for the relevant HelixDB query (`VectorSearch`). +#[derive(Deserialize, Serialize, Clone, Debug)] +struct QueryResult { + id: String, + score: f64, + doc: String, + json_payload: String, +} + +/// An input query. 
Only used internally as this is a representative type required for the relevant HelixDB query (`VectorSearch`). +#[derive(Deserialize, Serialize, Clone, Debug)] +struct QueryInput { + vector: Vec, + limit: u64, + threshold: f64, +} + +impl QueryInput { + /// Makes a new instance of `QueryInput`. + pub(crate) fn new(vector: Vec, limit: u64, threshold: f64) -> Self { + Self { + vector, + limit, + threshold, + } + } +} + +impl HelixDBVectorStore +where + C: HelixDBClient + Send, + E: EmbeddingModel, +{ + pub fn new(client: C, model: E) -> Self { + Self { client, model } + } + + pub fn client(&self) -> &C { + &self.client + } +} + +impl InsertDocuments for HelixDBVectorStore +where + C: HelixDBClient + Send + Sync, + E: EmbeddingModel + Send + Sync, +{ + async fn insert_documents( + &self, + documents: Vec<(Doc, rig::OneOrMany)>, + ) -> Result<(), VectorStoreError> { + #[derive(Serialize, Deserialize, Clone, Debug, Default)] + struct QueryInput { + vector: Vec, + doc: String, + json_payload: String, + } + + #[derive(Serialize, Deserialize, Clone, Debug, Default)] + struct QueryOutput { + doc: String, + } + + for (document, embeddings) in documents { + let json_document: serde_json::Value = serde_json::to_value(&document).unwrap(); + let json_document_as_string = serde_json::to_string(&json_document).unwrap(); + + for embedding in embeddings { + let embedded_text = embedding.document; + let vector: Vec = embedding.vec; + + let query = QueryInput { + vector, + doc: embedded_text, + json_payload: json_document_as_string.clone(), + }; + + self.client + .query::("InsertVector", &query) + .await + .inspect_err(|x| println!("Error: {x}")) + .map_err(|x| VectorStoreError::DatastoreError(x.to_string().into()))?; + } + } + Ok(()) + } +} + +impl VectorStoreIndex for HelixDBVectorStore +where + C: HelixDBClient + Send + Sync, + E: EmbeddingModel + Send + Sync, +{ + async fn top_n serde::Deserialize<'a> + Send>( + &self, + req: rig::vector_store::VectorSearchRequest, + ) -> Result, 
rig::vector_store::VectorStoreError> { + let vector = self.model.embed_text(req.query()).await?.vec; + + let query_input = + QueryInput::new(vector, req.samples(), req.threshold().unwrap_or_default()); + + #[derive(Serialize, Deserialize, Debug)] + struct VecResult { + vec_docs: Vec, + } + + let result: VecResult = self + .client + .query::("VectorSearch", &query_input) + .await + .unwrap(); + + let docs = result + .vec_docs + .into_iter() + .map(|x| { + let doc: T = serde_json::from_str(&x.json_payload)?; + + // HelixDB gives us the cosine distance, so we need to use `-(cosine_dist - 1)` to get the cosine similarity score. + Ok((-(x.score - 1.), x.id, doc)) + }) + .collect::, VectorStoreError>>()?; + + Ok(docs) + } + + async fn top_n_ids( + &self, + req: rig::vector_store::VectorSearchRequest, + ) -> Result, rig::vector_store::VectorStoreError> { + let vector = self.model.embed_text(req.query()).await?.vec; + + let query_input = + QueryInput::new(vector, req.samples(), req.threshold().unwrap_or_default()); + + #[derive(Serialize, Deserialize, Debug)] + struct VecResult { + vec_docs: Vec, + } + + let result: VecResult = self + .client + .query::("VectorSearch", &query_input) + .await + .unwrap(); + + // HelixDB gives us the cosine distance, so we need to use `-(cosine_dist - 1)` to get the cosine similarity score. 
+ let docs = result + .vec_docs + .into_iter() + .map(|x| Ok((-(x.score - 1.), x.id))) + .collect::, VectorStoreError>>()?; + + Ok(docs) + } +} diff --git a/rig-lancedb/CHANGELOG.md b/rig-lancedb/CHANGELOG.md index f8a6abc5a..bfbabe80e 100644 --- a/rig-lancedb/CHANGELOG.md +++ b/rig-lancedb/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.23](https://github.com/0xPlaygrounds/rig/compare/rig-lancedb-v0.2.22...rig-lancedb-v0.2.23) - 2025-09-29 + +### Other + +- updated the following local packages: rig-core + ## [0.2.22](https://github.com/0xPlaygrounds/rig/compare/rig-lancedb-v0.2.21...rig-lancedb-v0.2.22) - 2025-09-15 ### Other diff --git a/rig-lancedb/Cargo.toml b/rig-lancedb/Cargo.toml index e74c11f55..9f4ec06ec 100644 --- a/rig-lancedb/Cargo.toml +++ b/rig-lancedb/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rig-lancedb" -version = "0.2.22" +version = "0.2.23" edition = { workspace = true } license = "MIT" readme = "README.md" @@ -9,7 +9,7 @@ repository = "https://github.com/0xPlaygrounds/rig" [dependencies] lancedb = { workspace = true } -rig-core = { path = "../rig-core", version = "0.20.0" } +rig-core = { path = "../rig-core", version = "0.21.0" } arrow-array = { workspace = true } serde_json = { workspace = true } serde = { workspace = true } diff --git a/rig-milvus/CHANGELOG.md b/rig-milvus/CHANGELOG.md index 458aef8bb..2ef9627e9 100644 --- a/rig-milvus/CHANGELOG.md +++ b/rig-milvus/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.12](https://github.com/0xPlaygrounds/rig/compare/rig-milvus-v0.1.11...rig-milvus-v0.1.12) - 2025-09-29 + +### Other + +- updated the following local packages: rig-core + ## [0.1.11](https://github.com/0xPlaygrounds/rig/compare/rig-milvus-v0.1.10...rig-milvus-v0.1.11) - 2025-09-15 ### Other diff --git a/rig-milvus/Cargo.toml b/rig-milvus/Cargo.toml index 
dd078fd13..bb3517d84 100644 --- a/rig-milvus/Cargo.toml +++ b/rig-milvus/Cargo.toml @@ -1,13 +1,13 @@ [package] name = "rig-milvus" -version = "0.1.11" +version = "0.1.12" edition = { workspace = true } description = "Milvus vector store implementation for the rig framework" license = "MIT" [dependencies] reqwest = { workspace = true, features = ["json"] } -rig-core = { path = "../rig-core", version = "0.20.0" } +rig-core = { path = "../rig-core", version = "0.21.0" } serde = { workspace = true } serde_json = { workspace = true } uuid = { workspace = true, features = ["v4"] } diff --git a/rig-mongodb/CHANGELOG.md b/rig-mongodb/CHANGELOG.md index a0b1f6c8a..8297ccd64 100644 --- a/rig-mongodb/CHANGELOG.md +++ b/rig-mongodb/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.23](https://github.com/0xPlaygrounds/rig/compare/rig-mongodb-v0.2.22...rig-mongodb-v0.2.23) - 2025-09-29 + +### Other + +- updated the following local packages: rig-core + ## [0.2.22](https://github.com/0xPlaygrounds/rig/compare/rig-mongodb-v0.2.21...rig-mongodb-v0.2.22) - 2025-09-15 ### Other diff --git a/rig-mongodb/Cargo.toml b/rig-mongodb/Cargo.toml index 1bf652d03..23613e371 100644 --- a/rig-mongodb/Cargo.toml +++ b/rig-mongodb/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rig-mongodb" -version = "0.2.22" +version = "0.2.23" edition = { workspace = true } license = "MIT" readme = "README.md" @@ -12,7 +12,7 @@ repository = "https://github.com/0xPlaygrounds/rig" [dependencies] futures = { workspace = true } mongodb = { workspace = true } -rig-core = { path = "../rig-core", version = "0.20.0" } +rig-core = { path = "../rig-core", version = "0.21.0" } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tracing = { workspace = true } diff --git a/rig-neo4j/CHANGELOG.md b/rig-neo4j/CHANGELOG.md index ff2f2fac2..d78f93769 100644 --- a/rig-neo4j/CHANGELOG.md +++ 
b/rig-neo4j/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.3.7](https://github.com/0xPlaygrounds/rig/compare/rig-neo4j-v0.3.6...rig-neo4j-v0.3.7) - 2025-09-29 + +### Other + +- updated the following local packages: rig-core + ## [0.3.6](https://github.com/0xPlaygrounds/rig/compare/rig-neo4j-v0.3.5...rig-neo4j-v0.3.6) - 2025-09-15 ### Other diff --git a/rig-neo4j/Cargo.toml b/rig-neo4j/Cargo.toml index ba4f862bb..161280986 100644 --- a/rig-neo4j/Cargo.toml +++ b/rig-neo4j/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rig-neo4j" -version = "0.3.6" +version = "0.3.7" edition = { workspace = true } license = "MIT" readme = "README.md" @@ -12,7 +12,7 @@ repository = "https://github.com/0xPlaygrounds/rig" [dependencies] futures = { workspace = true } neo4rs = { workspace = true } -rig-core = { path = "../rig-core", version = "0.20.0" } +rig-core = { path = "../rig-core", version = "0.21.0" } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tracing = { workspace = true } diff --git a/rig-postgres/CHANGELOG.md b/rig-postgres/CHANGELOG.md index 4a34ca792..e9585c717 100644 --- a/rig-postgres/CHANGELOG.md +++ b/rig-postgres/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.21](https://github.com/0xPlaygrounds/rig/compare/rig-postgres-v0.1.20...rig-postgres-v0.1.21) - 2025-09-29 + +### Other + +- updated the following local packages: rig-core + ## [0.1.20](https://github.com/0xPlaygrounds/rig/compare/rig-postgres-v0.1.19...rig-postgres-v0.1.20) - 2025-09-15 ### Other diff --git a/rig-postgres/Cargo.toml b/rig-postgres/Cargo.toml index 3e4e7995c..b50b3ffbb 100644 --- a/rig-postgres/Cargo.toml +++ b/rig-postgres/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rig-postgres" -version = "0.1.20" +version = "0.1.21" edition = { workspace = true } description = "PostgreSQL-based 
vector store implementation for the rig framework" license = "MIT" @@ -8,7 +8,7 @@ readme = "README.md" repository = "https://github.com/0xPlaygrounds/rig" [dependencies] -rig-core = { path = "../rig-core", version = "0.20.0", features = ["derive"] } +rig-core = { path = "../rig-core", version = "0.21.0", features = ["derive"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } diff --git a/rig-qdrant/CHANGELOG.md b/rig-qdrant/CHANGELOG.md index 8b588f4a4..d3a62f93a 100644 --- a/rig-qdrant/CHANGELOG.md +++ b/rig-qdrant/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.26](https://github.com/0xPlaygrounds/rig/compare/rig-qdrant-v0.1.25...rig-qdrant-v0.1.26) - 2025-09-29 + +### Other + +- updated the following local packages: rig-core + ## [0.1.25](https://github.com/0xPlaygrounds/rig/compare/rig-qdrant-v0.1.24...rig-qdrant-v0.1.25) - 2025-09-15 ### Other diff --git a/rig-qdrant/Cargo.toml b/rig-qdrant/Cargo.toml index 7bc8eca2b..3561923e7 100644 --- a/rig-qdrant/Cargo.toml +++ b/rig-qdrant/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rig-qdrant" -version = "0.1.25" +version = "0.1.26" edition = { workspace = true } license = "MIT" readme = "README.md" @@ -8,7 +8,7 @@ description = "Rig vector store index integration for Qdrant. 
https://qdrant.tec repository = "https://github.com/0xPlaygrounds/rig" [dependencies] -rig-core = { path = "../rig-core", version = "0.20.0" } +rig-core = { path = "../rig-core", version = "0.21.0" } serde_json = { workspace = true } serde = { workspace = true } qdrant-client = { workspace = true } diff --git a/rig-s3vectors/CHANGELOG.md b/rig-s3vectors/CHANGELOG.md index a3a3b9056..83cb828d9 100644 --- a/rig-s3vectors/CHANGELOG.md +++ b/rig-s3vectors/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.9](https://github.com/0xPlaygrounds/rig/compare/rig-s3vectors-v0.1.8...rig-s3vectors-v0.1.9) - 2025-09-29 + +### Other + +- updated the following local packages: rig-core + ## [0.1.8](https://github.com/0xPlaygrounds/rig/compare/rig-s3vectors-v0.1.7...rig-s3vectors-v0.1.8) - 2025-09-15 ### Other diff --git a/rig-s3vectors/Cargo.toml b/rig-s3vectors/Cargo.toml index 796cf4660..0716156dd 100644 --- a/rig-s3vectors/Cargo.toml +++ b/rig-s3vectors/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rig-s3vectors" -version = "0.1.8" +version = "0.1.9" edition = { workspace = true } description = "AWS S3Vectors vector store implementation for the rig framework" license = "MIT" @@ -11,7 +11,7 @@ aws-smithy-types = { workspace = true, features = [ "serde-deserialize", "serde-serialize", ] } -rig-core = { path = "../rig-core", version = "0.20.0", features = ["derive"] } +rig-core = { path = "../rig-core", version = "0.21.0", features = ["derive"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tracing = { workspace = true } diff --git a/rig-scylladb/CHANGELOG.md b/rig-scylladb/CHANGELOG.md index be4124621..25ca4a365 100644 --- a/rig-scylladb/CHANGELOG.md +++ b/rig-scylladb/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## 
[0.1.12](https://github.com/0xPlaygrounds/rig/compare/rig-scylladb-v0.1.11...rig-scylladb-v0.1.12) - 2025-09-29 + +### Fixed + +- ci lints ([#832](https://github.com/0xPlaygrounds/rig/pull/832)) + ## [0.1.11](https://github.com/0xPlaygrounds/rig/compare/rig-scylladb-v0.1.10...rig-scylladb-v0.1.11) - 2025-09-15 ### Other diff --git a/rig-scylladb/Cargo.toml b/rig-scylladb/Cargo.toml index 3796732e3..dbabb5f66 100644 --- a/rig-scylladb/Cargo.toml +++ b/rig-scylladb/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rig-scylladb" -version = "0.1.11" +version = "0.1.12" edition = { workspace = true } license = "MIT" readme = "README.md" @@ -8,7 +8,7 @@ description = "ScyllaDB vector store index integration for Rig. High-performance repository = "https://github.com/0xPlaygrounds/rig" [dependencies] -rig-core = { path = "../rig-core", version = "0.20.0", features = ["derive"] } +rig-core = { path = "../rig-core", version = "0.21.0", features = ["derive"] } serde_json = { workspace = true } serde = { workspace = true, features = ["derive"] } scylla = { workspace = true } diff --git a/rig-sqlite/CHANGELOG.md b/rig-sqlite/CHANGELOG.md index 6699f6b88..5a0d287d0 100644 --- a/rig-sqlite/CHANGELOG.md +++ b/rig-sqlite/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.23](https://github.com/0xPlaygrounds/rig/compare/rig-sqlite-v0.1.22...rig-sqlite-v0.1.23) - 2025-09-29 + +### Other + +- updated the following local packages: rig-core + ## [0.1.22](https://github.com/0xPlaygrounds/rig/compare/rig-sqlite-v0.1.21...rig-sqlite-v0.1.22) - 2025-09-15 ### Other diff --git a/rig-sqlite/Cargo.toml b/rig-sqlite/Cargo.toml index 74a8e5a8e..d79647459 100644 --- a/rig-sqlite/Cargo.toml +++ b/rig-sqlite/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rig-sqlite" -version = "0.1.22" +version = "0.1.23" edition = { workspace = true } description = "SQLite-based vector store implementation for the rig framework" license 
= "MIT" @@ -9,7 +9,7 @@ license = "MIT" doctest = false [dependencies] -rig-core = { path = "../rig-core", version = "0.20.0", features = ["derive"] } +rig-core = { path = "../rig-core", version = "0.21.0", features = ["derive"] } rusqlite = { workspace = true, features = ["bundled"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } diff --git a/rig-surrealdb/CHANGELOG.md b/rig-surrealdb/CHANGELOG.md index bbadfa230..46029f536 100644 --- a/rig-surrealdb/CHANGELOG.md +++ b/rig-surrealdb/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.18](https://github.com/0xPlaygrounds/rig/compare/rig-surrealdb-v0.1.17...rig-surrealdb-v0.1.18) - 2025-09-29 + +### Fixed + +- *(rig-944)* surrealdb WHERE clause causes no results ([#821](https://github.com/0xPlaygrounds/rig/pull/821)) + ## [0.1.17](https://github.com/0xPlaygrounds/rig/compare/rig-surrealdb-v0.1.16...rig-surrealdb-v0.1.17) - 2025-09-15 ### Other diff --git a/rig-surrealdb/Cargo.toml b/rig-surrealdb/Cargo.toml index 826a41f20..11813687b 100644 --- a/rig-surrealdb/Cargo.toml +++ b/rig-surrealdb/Cargo.toml @@ -1,13 +1,13 @@ [package] name = "rig-surrealdb" -version = "0.1.17" +version = "0.1.18" edition = { workspace = true } description = "SurrealDB vector store implementation for the rig framework" license = "MIT" [dependencies] surrealdb = { workspace = true, features = ["protocol-ws", "kv-mem"] } -rig-core = { path = "../rig-core", version = "0.20.0", features = ["derive"] } +rig-core = { path = "../rig-core", version = "0.21.0", features = ["derive"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tracing = { workspace = true } diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 000000000..ff100edcb --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "1.90.0" From 
ac3dee8b0e271fe3dfdff8bf21bf1bd80288ce10 Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Mon, 6 Oct 2025 17:19:54 -0400 Subject: [PATCH 08/20] Clippy --- rig-core/src/providers/azure.rs | 10 +++++----- rig-core/src/providers/ollama.rs | 1 - rig-core/src/providers/openrouter/client.rs | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/rig-core/src/providers/azure.rs b/rig-core/src/providers/azure.rs index d0e1d74c1..0e49e3e08 100644 --- a/rig-core/src/providers/azure.rs +++ b/rig-core/src/providers/azure.rs @@ -206,7 +206,7 @@ where } impl Client { - fn post_reqwest(&self, url: String) -> reqwest::RequestBuilder { + fn reqwest_post(&self, url: String) -> reqwest::RequestBuilder { let (key, val) = self.auth.as_header(); self.http_client.post(url).header(key, val) @@ -220,7 +220,7 @@ impl Client { ) .replace("//", "/"); - self.post_reqwest(url) + self.reqwest_post(url) } fn post_chat_completion(&self, deployment_id: &str) -> reqwest::RequestBuilder { @@ -230,7 +230,7 @@ impl Client { ) .replace("//", "/"); - self.post_reqwest(url) + self.reqwest_post(url) } fn post_transcription(&self, deployment_id: &str) -> reqwest::RequestBuilder { @@ -240,7 +240,7 @@ impl Client { ) .replace("//", "/"); - self.post_reqwest(url) + self.reqwest_post(url) } #[cfg(feature = "image")] @@ -251,7 +251,7 @@ impl Client { ) .replace("//", "/"); - self.post_reqwest(url) + self.reqwest_post(url) } } diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs index b1bf268b7..867894147 100644 --- a/rig-core/src/providers/ollama.rs +++ b/rig-core/src/providers/ollama.rs @@ -63,7 +63,6 @@ use serde_json::{Value, json}; use std::{convert::TryFrom, str::FromStr}; use tracing::info_span; use tracing_futures::Instrument; -use url::Url; // ---------- Main Client ---------- const OLLAMA_API_BASE_URL: &str = "http://localhost:11434"; diff --git a/rig-core/src/providers/openrouter/client.rs b/rig-core/src/providers/openrouter/client.rs index 60e005fe1..7d3a7614e 
100644 --- a/rig-core/src/providers/openrouter/client.rs +++ b/rig-core/src/providers/openrouter/client.rs @@ -1,5 +1,5 @@ use crate::{ - client::{ClientBuilderError, CompletionClient, ProviderClient, VerifyClient, VerifyError}, + client::{CompletionClient, ProviderClient, VerifyClient, VerifyError}, completion::GetTokenUsage, http_client::{self, HttpClientExt}, impl_conversion_traits, From 81b553235357bb73b0e780feee1dec67d827b384 Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Mon, 6 Oct 2025 17:45:27 -0400 Subject: [PATCH 09/20] Fix duplication --- rig-core/src/providers/gemini/completion.rs | 5 --- .../src/providers/openai/completion/mod.rs | 40 +------------------ rig-postgres/tests/integration_tests.rs | 3 +- rig-qdrant/tests/integration_tests.rs | 3 +- rig-scylladb/tests/integration_tests.rs | 6 +-- rig-sqlite/tests/integration_test.rs | 3 +- 6 files changed, 6 insertions(+), 54 deletions(-) diff --git a/rig-core/src/providers/gemini/completion.rs b/rig-core/src/providers/gemini/completion.rs index c91c28863..f9ef11d61 100644 --- a/rig-core/src/providers/gemini/completion.rs +++ b/rig-core/src/providers/gemini/completion.rs @@ -806,11 +806,6 @@ pub mod gemini_api_types { "Raw files not supported, encode as base64 first".into(), )); } - DocumentSourceKind::Raw(_) => { - return Err(message::MessageError::ConversionError( - "Raw files not supported, encode as base64 first".into(), - )); - } _ => { return Err(message::MessageError::ConversionError( "Document has no body".to_string(), diff --git a/rig-core/src/providers/openai/completion/mod.rs b/rig-core/src/providers/openai/completion/mod.rs index 8e994ff69..62a6ef270 100644 --- a/rig-core/src/providers/openai/completion/mod.rs +++ b/rig-core/src/providers/openai/completion/mod.rs @@ -6,7 +6,7 @@ use super::{ApiErrorResponse, ApiResponse, Client, streaming::StreamingCompletio use crate::completion::{ CompletionError, CompletionRequest as CoreCompletionRequest, GetTokenUsage, }; -use 
crate::http_client::HttpClientExt; +use crate::http_client::{self, HttpClientExt}; use crate::message::{AudioMediaType, DocumentSourceKind, ImageDetail, MimeType}; use crate::one_or_many::string_or_one_or_many; use crate::telemetry::{ProviderResponseExt, SpanCombinator}; @@ -913,44 +913,6 @@ impl crate::telemetry::ProviderRequestExt for CompletionRequest { } } -impl crate::telemetry::ProviderRequestExt for CompletionRequest { - type InputMessage = Message; - - fn get_input_messages(&self) -> Vec { - self.messages.clone() - } - - fn get_system_prompt(&self) -> Option { - let first_message = self.messages.first()?; - - let Message::System { ref content, .. } = first_message.clone() else { - return None; - }; - - let SystemContent { text, .. } = content.first(); - - Some(text) - } - - fn get_prompt(&self) -> Option { - let last_message = self.messages.last()?; - - let Message::User { ref content, .. } = last_message.clone() else { - return None; - }; - - let UserContent::Text { text } = content.first() else { - return None; - }; - - Some(text) - } - - fn get_model_name(&self) -> String { - self.model.clone() - } -} - impl CompletionModel { pub fn into_agent_builder(self) -> crate::agent::AgentBuilder { crate::agent::AgentBuilder::new(self) diff --git a/rig-postgres/tests/integration_tests.rs b/rig-postgres/tests/integration_tests.rs index 0be88d667..f3dc6d5d6 100644 --- a/rig-postgres/tests/integration_tests.rs +++ b/rig-postgres/tests/integration_tests.rs @@ -52,8 +52,7 @@ async fn vector_search_test() { let openai_mock = create_openai_mock_service().await; let openai_client = rig::providers::openai::Client::builder("TEST") .base_url(&openai_mock.base_url()) - .build() - .unwrap(); + .build(); let model = openai_client.embedding_model(rig::providers::openai::TEXT_EMBEDDING_ADA_002); diff --git a/rig-qdrant/tests/integration_tests.rs b/rig-qdrant/tests/integration_tests.rs index 926d5f8ad..937de2ff5 100644 --- a/rig-qdrant/tests/integration_tests.rs +++ 
b/rig-qdrant/tests/integration_tests.rs @@ -142,8 +142,7 @@ async fn vector_search_test() { // Initialize OpenAI client let openai_client = openai::Client::builder("TEST") .base_url(&server.base_url()) - .build() - .unwrap(); + .build(); // let openai_client = openai::Client::from_env(); let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); diff --git a/rig-scylladb/tests/integration_tests.rs b/rig-scylladb/tests/integration_tests.rs index 5591b973a..54b28e561 100644 --- a/rig-scylladb/tests/integration_tests.rs +++ b/rig-scylladb/tests/integration_tests.rs @@ -75,8 +75,7 @@ async fn vector_search_test() { let openai_mock = create_openai_mock_service().await; let openai_client = rig::providers::openai::Client::builder("TEST") .base_url(&openai_mock.base_url()) - .build() - .unwrap(); + .build(); let model = openai_client.embedding_model(rig::providers::openai::TEXT_EMBEDDING_ADA_002); @@ -348,8 +347,7 @@ async fn test_mock_server_setup() { let server = create_openai_mock_service().await; let openai_client = rig::providers::openai::Client::builder("TEST") .base_url(&server.base_url()) - .build() - .unwrap(); + .build(); let model = openai_client.embedding_model(rig::providers::openai::TEXT_EMBEDDING_ADA_002); // Test that we can create embeddings with the mock diff --git a/rig-sqlite/tests/integration_test.rs b/rig-sqlite/tests/integration_test.rs index 443ac03cc..bd5cc2bc9 100644 --- a/rig-sqlite/tests/integration_test.rs +++ b/rig-sqlite/tests/integration_test.rs @@ -139,8 +139,7 @@ async fn vector_search_test() { // Initialize OpenAI client let openai_client = openai::Client::builder("TEST") .base_url(&server.base_url()) - .build() - .unwrap(); + .build(); // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); From a9c1d30148ac9e0dd5e03663b7b78cd6a8681dd7 Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Mon, 6 Oct 2025 17:54:49 -0400 Subject: [PATCH 10/20] Fix CI? 
--- rig-core/src/http_client.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/rig-core/src/http_client.rs b/rig-core/src/http_client.rs index 6a5a033f9..7a8bbab73 100644 --- a/rig-core/src/http_client.rs +++ b/rig-core/src/http_client.rs @@ -118,9 +118,7 @@ impl HttpClientExt for reqwest::Client { async move { let response = req.send().await.map_err(instance_error)?; - let mut res = Response::builder() - .status(response.status()) - .version(response.version()); + let mut res = Response::builder().status(response.status()); if let Some(hs) = res.headers_mut() { *hs = response.headers().clone(); @@ -160,8 +158,10 @@ impl HttpClientExt for reqwest::Client { *hs = response.headers().clone(); } - let stream: ByteStream = - Box::pin(response.bytes_stream().map(|r| r.map_err(instance_error))); + let stream: ByteStream = { + use futures::TryStreamExt; + Box::pin(response.bytes_stream().map_err(instance_error).boxed()) + }; Ok(res.body(stream)?) } From 1ab933e57ef5443fda12298fb23ee76a08d79f89 Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Tue, 7 Oct 2025 11:06:44 -0400 Subject: [PATCH 11/20] Wasm compat, rename `HttpClientExt` methods --- rig-core/src/http_client.rs | 32 ++++++++++++++------ rig-core/src/providers/anthropic/client.rs | 8 ++--- rig-core/src/providers/azure.rs | 4 +-- rig-core/src/providers/cohere/client.rs | 4 +-- rig-core/src/providers/deepseek.rs | 4 +-- rig-core/src/providers/galadriel.rs | 4 +-- rig-core/src/providers/gemini/client.rs | 8 ++--- rig-core/src/providers/groq.rs | 2 +- rig-core/src/providers/huggingface/client.rs | 4 +-- rig-core/src/providers/hyperbolic.rs | 2 +- rig-core/src/providers/mira.rs | 4 +-- rig-core/src/providers/mistral/client.rs | 6 ++-- rig-core/src/providers/moonshot.rs | 2 +- rig-core/src/providers/ollama.rs | 4 +-- rig-core/src/providers/openai/client.rs | 8 +++-- rig-core/src/providers/openrouter/client.rs | 2 +- rig-core/src/providers/together/client.rs | 6 ++-- 17 files changed, 60 
insertions(+), 44 deletions(-) diff --git a/rig-core/src/http_client.rs b/rig-core/src/http_client.rs index 7a8bbab73..d8a9c5116 100644 --- a/rig-core/src/http_client.rs +++ b/rig-core/src/http_client.rs @@ -5,6 +5,16 @@ use reqwest::Body; use std::future::Future; use std::pin::Pin; +#[cfg(not(target_arch = "wasm32"))] +pub trait RigSend: Send {} +#[cfg(target_arch = "wasm32")] +pub trait RigSend {} + +#[cfg(not(target_arch = "wasm32"))] +impl RigSend for T {} +#[cfg(target_arch = "wasm32")] +impl RigSend for T {} + #[derive(Debug, thiserror::Error)] pub enum Error { #[error("Http error: {0}")] @@ -52,15 +62,17 @@ pub fn with_bearer_auth(req: Builder, auth: &str) -> Result { } pub trait HttpClientExt: Send + Sync { - fn request( + fn send( &self, req: Request, ) -> impl Future>>> + Send where - T: Into + Send, - U: From + Send; + T: Into, + T: RigSend, + U: From, + U: RigSend + 'static; - fn request_streaming( + fn send_streaming( &self, req: Request, ) -> impl Future> + Send @@ -69,7 +81,7 @@ pub trait HttpClientExt: Send + Sync { fn get(&self, uri: Uri) -> impl Future>>> + Send where - T: From + Send, + T: From + Send + 'static, { async { let req = Request::builder() @@ -77,7 +89,7 @@ pub trait HttpClientExt: Send + Sync { .uri(uri) .body(NoBody)?; - self.request(req).await + self.send(req).await } } @@ -88,20 +100,20 @@ pub trait HttpClientExt: Send + Sync { ) -> impl Future>>> + Send where T: Into + Send, - R: From + Send, + R: From + Send + 'static, { async { let req = Request::builder() .method(Method::POST) .uri(uri) .body(body)?; - self.request(req).await + self.send(req).await } } } impl HttpClientExt for reqwest::Client { - fn request( + fn send( &self, req: Request, ) -> impl Future>>> + Send @@ -134,7 +146,7 @@ impl HttpClientExt for reqwest::Client { } } - fn request_streaming( + fn send_streaming( &self, req: Request, ) -> impl Future> + Send diff --git a/rig-core/src/providers/anthropic/client.rs b/rig-core/src/providers/anthropic/client.rs index 
e75c92147..4648d43fd 100644 --- a/rig-core/src/providers/anthropic/client.rs +++ b/rig-core/src/providers/anthropic/client.rs @@ -156,9 +156,9 @@ where ) -> Result>, http_client::Error> where U: Into + Send, - V: From + Send, + V: From + Send + 'static, { - self.http_client.request(req).await + self.http_client.send(req).await } pub async fn send_streaming( @@ -168,7 +168,7 @@ where where U: Into, { - self.http_client.request_streaming(req).await + self.http_client.send_streaming(req).await } pub(crate) fn post(&self, path: &str) -> http_client::Builder { @@ -258,7 +258,7 @@ impl VerifyClient for Client { .body(http_client::NoBody) .map_err(http_client::Error::from)?; - let response = HttpClientExt::request(&self.http_client, req).await?; + let response = HttpClientExt::send(&self.http_client, req).await?; match response.status() { http::StatusCode::OK => Ok(()), diff --git a/rig-core/src/providers/azure.rs b/rig-core/src/providers/azure.rs index 0e49e3e08..0ab3af8c6 100644 --- a/rig-core/src/providers/azure.rs +++ b/rig-core/src/providers/azure.rs @@ -199,9 +199,9 @@ where ) -> http_client::Result>> where U: Into + Send, - R: From + Send, + R: From + Send + 'static, { - self.http_client.request(req).await + self.http_client.send(req).await } } diff --git a/rig-core/src/providers/cohere/client.rs b/rig-core/src/providers/cohere/client.rs index 6526bbbf2..91a1bd64a 100644 --- a/rig-core/src/providers/cohere/client.rs +++ b/rig-core/src/providers/cohere/client.rs @@ -132,9 +132,9 @@ where ) -> http_client::Result>> where U: Into + Send, - V: From + Send, + V: From + Send + 'static, { - self.http_client.request(req).await + self.http_client.send(req).await } pub fn embeddings( diff --git a/rig-core/src/providers/deepseek.rs b/rig-core/src/providers/deepseek.rs index 5bc45a74e..09040ff81 100644 --- a/rig-core/src/providers/deepseek.rs +++ b/rig-core/src/providers/deepseek.rs @@ -152,9 +152,9 @@ where ) -> http_client::Result>> where U: Into + Send, - R: From + Send, + 
R: From + Send + 'static, { - self.http_client.request(req).await + self.http_client.send(req).await } } diff --git a/rig-core/src/providers/galadriel.rs b/rig-core/src/providers/galadriel.rs index c29858970..5ab3bd402 100644 --- a/rig-core/src/providers/galadriel.rs +++ b/rig-core/src/providers/galadriel.rs @@ -153,9 +153,9 @@ where ) -> http_client::Result>> where U: Into + Send, - R: From + Send, + R: From + Send + 'static, { - self.http_client.request(req).await + self.http_client.send(req).await } } diff --git a/rig-core/src/providers/gemini/client.rs b/rig-core/src/providers/gemini/client.rs index e521d58c7..e61cd4f01 100644 --- a/rig-core/src/providers/gemini/client.rs +++ b/rig-core/src/providers/gemini/client.rs @@ -179,9 +179,9 @@ where ) -> http_client::Result>> where U: Into + Send, - R: From + Send, + R: From + Send + 'static, { - self.http_client.request(req).await + self.http_client.send(req).await } } @@ -303,14 +303,14 @@ where .get("/v1beta/models") .body(http_client::NoBody) .map_err(|e| VerifyError::HttpError(e.into()))?; - let response = self.http_client.request::<_, Vec>(req).await?; + let response = self.http_client.send::<_, Vec>(req).await?; match response.status() { reqwest::StatusCode::OK => Ok(()), reqwest::StatusCode::FORBIDDEN => Err(VerifyError::InvalidAuthentication), reqwest::StatusCode::INTERNAL_SERVER_ERROR | reqwest::StatusCode::SERVICE_UNAVAILABLE => { - let text = String::from_utf8_lossy(&response.into_body().await?).into(); + let text = http_client::text(response).await?; Err(VerifyError::ProviderError(text)) } _ => { diff --git a/rig-core/src/providers/groq.rs b/rig-core/src/providers/groq.rs index 5e61ffe5a..47868814d 100644 --- a/rig-core/src/providers/groq.rs +++ b/rig-core/src/providers/groq.rs @@ -223,7 +223,7 @@ impl VerifyClient for Client { .body(http_client::NoBody) .map_err(http_client::Error::from)?; - let response = HttpClientExt::request(&self.http_client, req).await?; + let response = 
HttpClientExt::send(&self.http_client, req).await?; match response.status() { reqwest::StatusCode::OK => Ok(()), diff --git a/rig-core/src/providers/huggingface/client.rs b/rig-core/src/providers/huggingface/client.rs index 304922f31..bbe695836 100644 --- a/rig-core/src/providers/huggingface/client.rs +++ b/rig-core/src/providers/huggingface/client.rs @@ -269,9 +269,9 @@ where ) -> http_client::Result>> where U: Into + Send, - V: From + Send, + V: From + Send + 'static, { - self.http_client.request(req).await + self.http_client.send(req).await } } diff --git a/rig-core/src/providers/hyperbolic.rs b/rig-core/src/providers/hyperbolic.rs index 8b15776c4..bc27a70dd 100644 --- a/rig-core/src/providers/hyperbolic.rs +++ b/rig-core/src/providers/hyperbolic.rs @@ -194,7 +194,7 @@ impl VerifyClient for Client { .body(http_client::NoBody) .map_err(http_client::Error::from)?; - let response = HttpClientExt::request(&self.http_client, req).await?; + let response = HttpClientExt::send(&self.http_client, req).await?; match response.status() { reqwest::StatusCode::OK => Ok(()), diff --git a/rig-core/src/providers/mira.rs b/rig-core/src/providers/mira.rs index 08640d9ad..24897e4ea 100644 --- a/rig-core/src/providers/mira.rs +++ b/rig-core/src/providers/mira.rs @@ -225,7 +225,7 @@ where .map_err(http_client::Error::Protocol) })?; - let response = self.http_client.request(req).await?; + let response = self.http_client.send(req).await?; let status = response.status(); @@ -310,7 +310,7 @@ impl VerifyClient for Client { .body(http_client::NoBody) .map_err(http_client::Error::from)?; - let response = HttpClientExt::request(&self.http_client, req).await?; + let response = HttpClientExt::send(&self.http_client, req).await?; match response.status() { reqwest::StatusCode::OK => Ok(()), diff --git a/rig-core/src/providers/mistral/client.rs b/rig-core/src/providers/mistral/client.rs index 48f127b31..0e10cee31 100644 --- a/rig-core/src/providers/mistral/client.rs +++ 
b/rig-core/src/providers/mistral/client.rs @@ -119,9 +119,9 @@ where ) -> http_client::Result>> where Body: Into + Send, - R: From + Send, + R: From + Send + 'static, { - self.http_client.request(req).await + self.http_client.send(req).await } } @@ -199,7 +199,7 @@ impl VerifyClient for Client { .body(http_client::NoBody) .map_err(|e| VerifyError::HttpError(e.into()))?; - let response = HttpClientExt::request(&self.http_client, req).await?; + let response = HttpClientExt::send(&self.http_client, req).await?; match response.status() { reqwest::StatusCode::OK => Ok(()), diff --git a/rig-core/src/providers/moonshot.rs b/rig-core/src/providers/moonshot.rs index fc7cdea31..ad99ac821 100644 --- a/rig-core/src/providers/moonshot.rs +++ b/rig-core/src/providers/moonshot.rs @@ -190,7 +190,7 @@ impl VerifyClient for Client { .body(http_client::NoBody) .map_err(http_client::Error::from)?; - let response = HttpClientExt::request(&self.http_client, req).await?; + let response = HttpClientExt::send(&self.http_client, req).await?; match response.status() { reqwest::StatusCode::OK => Ok(()), diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs index 867894147..07a378983 100644 --- a/rig-core/src/providers/ollama.rs +++ b/rig-core/src/providers/ollama.rs @@ -218,7 +218,7 @@ impl VerifyClient for Client { .body(http_client::NoBody) .map_err(http_client::Error::from)?; - let response = HttpClientExt::request(&self.http_client, req).await?; + let response = HttpClientExt::send(&self.http_client, req).await?; match response.status() { reqwest::StatusCode::OK => Ok(()), @@ -331,7 +331,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { .body(body) .map_err(|e| EmbeddingError::HttpError(e.into()))?; - let response = HttpClientExt::request(&self.client.http_client, req).await?; + let response = HttpClientExt::send(&self.client.http_client, req).await?; if !response.status().is_success() { let text = http_client::text(response).await?; diff --git 
a/rig-core/src/providers/openai/client.rs b/rig-core/src/providers/openai/client.rs index 33502a978..fe1599ab4 100644 --- a/rig-core/src/providers/openai/client.rs +++ b/rig-core/src/providers/openai/client.rs @@ -126,12 +126,16 @@ where pub(crate) fn post(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + dbg!(&url); + http_client::with_bearer_auth(http_client::Request::post(url), &self.api_key) } pub(crate) fn get(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + dbg!(&url); + http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key) } @@ -141,9 +145,9 @@ where ) -> http_client::Result>> where U: Into + Send, - R: From + Send, + R: From + Send + 'static, { - self.http_client.request(req).await + self.http_client.send(req).await } } diff --git a/rig-core/src/providers/openrouter/client.rs b/rig-core/src/providers/openrouter/client.rs index 7d3a7614e..f44446e0b 100644 --- a/rig-core/src/providers/openrouter/client.rs +++ b/rig-core/src/providers/openrouter/client.rs @@ -166,7 +166,7 @@ impl VerifyClient for Client { .body(http_client::NoBody) .map_err(|e| VerifyError::HttpError(e.into()))?; - let response = HttpClientExt::request(&self.http_client, req).await?; + let response = HttpClientExt::send(&self.http_client, req).await?; match response.status() { reqwest::StatusCode::OK => Ok(()), diff --git a/rig-core/src/providers/together/client.rs b/rig-core/src/providers/together/client.rs index 29e9b93f9..a673f530a 100644 --- a/rig-core/src/providers/together/client.rs +++ b/rig-core/src/providers/together/client.rs @@ -146,9 +146,9 @@ where ) -> http_client::Result>> where U: Into + Send, - R: From + Send, + R: From + Send + 'static, { - self.http_client.request(req).await + self.http_client.send(req).await } } @@ -243,7 +243,7 @@ impl VerifyClient for Client { .body(http_client::NoBody) .map_err(|e| 
VerifyError::HttpError(e.into()))?; - let response = HttpClientExt::request(&self.http_client, req).await?; + let response = HttpClientExt::send(&self.http_client, req).await?; match response.status() { reqwest::StatusCode::OK => Ok(()), From 4a98b905e8128dc99d235b1c428d7b611b016e3c Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Tue, 7 Oct 2025 11:19:11 -0400 Subject: [PATCH 12/20] Squash merge main into HTTP-Trait-w-dep --- flake.nix | 12 +++++++++--- rig-helixdb/src/lib.rs | 2 ++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/flake.nix b/flake.nix index e8b4a183b..09d10da90 100644 --- a/flake.nix +++ b/flake.nix @@ -21,21 +21,27 @@ let overlays = [ (import rust-overlay) ]; pkgs = import nixpkgs { - inherit system overlays; - }; + inherit system overlays; + }; + rustToolchain = pkgs.rust-bin.stable."1.90.0".default.override { + targets = [ "wasm32-unknown-unknown" ]; + }; in { devShells.default = with pkgs; mkShell { buildInputs = [ pkg-config cmake + just openssl sqlite postgresql protobuf - rust-bin.stable."1.90.0".default + rustToolchain + wasm-bindgen-cli + wasm-pack ]; OPENSSL_DEV = openssl.dev; diff --git a/rig-helixdb/src/lib.rs b/rig-helixdb/src/lib.rs index 614aa05b0..b082cb25d 100644 --- a/rig-helixdb/src/lib.rs +++ b/rig-helixdb/src/lib.rs @@ -138,6 +138,7 @@ where let docs = result .vec_docs .into_iter() + .filter(|x| -(x.score - 1.) >= req.threshold().unwrap_or_default()) .map(|x| { let doc: T = serde_json::from_str(&x.json_payload)?; @@ -173,6 +174,7 @@ where let docs = result .vec_docs .into_iter() + .filter(|x| -(x.score - 1.) 
>= req.threshold().unwrap_or_default()) .map(|x| Ok((-(x.score - 1.), x.id))) .collect::, VectorStoreError>>()?; From 4d497a353d9720fe4483d0808ddac418dd3e3fd7 Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Thu, 9 Oct 2025 15:54:39 -0400 Subject: [PATCH 13/20] Wasm compatibility --- rig-core/src/agent/completion.rs | 18 ++- rig-core/src/agent/prompt_request/mod.rs | 23 +-- .../src/agent/prompt_request/streaming.rs | 7 +- rig-core/src/cli_chatbot.rs | 3 +- rig-core/src/client/completion.rs | 5 +- rig-core/src/client/verify.rs | 8 +- rig-core/src/completion/request.rs | 48 ++++--- rig-core/src/embeddings/embedding.rs | 134 +++++++++++++----- rig-core/src/extractor.rs | 22 +-- rig-core/src/http_client.rs | 96 ++++++------- rig-core/src/lib.rs | 1 + rig-core/src/pipeline/agent_ops.rs | 34 +++-- rig-core/src/pipeline/op.rs | 69 ++++----- rig-core/src/pipeline/try_op.rs | 42 +++--- .../src/providers/anthropic/completion.rs | 8 +- .../src/providers/anthropic/decoders/sse.rs | 43 ++++-- rig-core/src/providers/anthropic/streaming.rs | 2 +- rig-core/src/providers/cohere/client.rs | 3 +- rig-core/src/providers/cohere/embeddings.rs | 3 +- rig-core/src/providers/gemini/client.rs | 3 +- rig-core/src/providers/gemini/embedding.rs | 5 +- .../src/providers/gemini/transcription.rs | 3 +- .../providers/huggingface/transcription.rs | 3 +- rig-core/src/providers/openai/client.rs | 2 - rig-core/src/streaming.rs | 23 +-- rig-core/src/tool.rs | 61 ++++---- rig-core/src/transcription.rs | 21 ++- rig-core/src/vector_store/mod.rs | 34 +++-- rig-core/src/wasm_compat.rs | 45 ++++++ 29 files changed, 470 insertions(+), 299 deletions(-) create mode 100644 rig-core/src/wasm_compat.rs diff --git a/rig-core/src/agent/completion.rs b/rig-core/src/agent/completion.rs index adc85259f..5bb39426e 100644 --- a/rig-core/src/agent/completion.rs +++ b/rig-core/src/agent/completion.rs @@ -9,6 +9,7 @@ use crate::{ streaming::{StreamingChat, StreamingCompletion, StreamingPrompt}, tool::ToolSet, 
vector_store::{VectorStoreError, request::VectorSearchRequest}, + wasm_compat::WasmCompatSend, }; use futures::{StreamExt, TryStreamExt, stream}; use std::{collections::HashMap, sync::Arc}; @@ -85,7 +86,7 @@ where { async fn completion( &self, - prompt: impl Into + Send, + prompt: impl Into + WasmCompatSend, chat_history: Vec, ) -> Result, CompletionError> { let prompt = prompt.into(); @@ -228,7 +229,7 @@ where { fn prompt( &self, - prompt: impl Into + Send, + prompt: impl Into + WasmCompatSend, ) -> PromptRequest<'_, prompt_request::Standard, M, ()> { PromptRequest::new(self, prompt) } @@ -242,7 +243,7 @@ where #[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))] fn prompt( &self, - prompt: impl Into + Send, + prompt: impl Into + WasmCompatSend, ) -> PromptRequest<'_, prompt_request::Standard, M, ()> { PromptRequest::new(*self, prompt) } @@ -256,7 +257,7 @@ where #[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))] async fn chat( &self, - prompt: impl Into + Send, + prompt: impl Into + WasmCompatSend, mut chat_history: Vec, ) -> Result { PromptRequest::new(self, prompt) @@ -271,7 +272,7 @@ where { async fn stream_completion( &self, - prompt: impl Into + Send, + prompt: impl Into + WasmCompatSend, chat_history: Vec, ) -> Result, CompletionError> { // Reuse the existing completion implementation to build the request @@ -285,7 +286,10 @@ where M: CompletionModel + 'static, M::StreamingResponse: GetTokenUsage, { - fn stream_prompt(&self, prompt: impl Into + Send) -> StreamingPromptRequest { + fn stream_prompt( + &self, + prompt: impl Into + WasmCompatSend, + ) -> StreamingPromptRequest { let arc = Arc::new(self.clone()); StreamingPromptRequest::new(arc, prompt) } @@ -298,7 +302,7 @@ where { fn stream_chat( &self, - prompt: impl Into + Send, + prompt: impl Into + WasmCompatSend, chat_history: Vec, ) -> StreamingPromptRequest { let arc = Arc::new(self.clone()); diff --git 
a/rig-core/src/agent/prompt_request/mod.rs b/rig-core/src/agent/prompt_request/mod.rs index 7a400073d..e03dc55a5 100644 --- a/rig-core/src/agent/prompt_request/mod.rs +++ b/rig-core/src/agent/prompt_request/mod.rs @@ -15,6 +15,7 @@ use crate::{ completion::{Completion, CompletionError, CompletionModel, Message, PromptError, Usage}, message::{AssistantContent, UserContent}, tool::ToolSetError, + wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync}, }; use super::Agent; @@ -136,7 +137,7 @@ where // dead code allowed because of functions being left empty to allow for users to not have to implement every single function /// Trait for per-request hooks to observe tool call events. -pub trait PromptHook: Clone + Send + Sync +pub trait PromptHook: Clone + WasmCompatSend + WasmCompatSync where M: CompletionModel, { @@ -146,7 +147,7 @@ where &self, prompt: &Message, history: &[Message], - ) -> impl Future + Send { + ) -> impl Future + WasmCompatSend { async {} } @@ -156,13 +157,17 @@ where &self, prompt: &Message, response: &crate::completion::CompletionResponse, - ) -> impl Future + Send { + ) -> impl Future + WasmCompatSend { async {} } #[allow(unused_variables)] /// Called before a tool is invoked. 
- fn on_tool_call(&self, tool_name: &str, args: &str) -> impl Future + Send { + fn on_tool_call( + &self, + tool_name: &str, + args: &str, + ) -> impl Future + WasmCompatSend { async {} } @@ -173,7 +178,7 @@ where tool_name: &str, args: &str, result: &str, - ) -> impl Future + Send { + ) -> impl Future + WasmCompatSend { async {} } } @@ -189,10 +194,10 @@ where P: PromptHook + 'static, { type Output = Result; - type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent + type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent fn into_future(self) -> Self::IntoFuture { - self.send().boxed() + Box::pin(self.send()) } } @@ -202,10 +207,10 @@ where P: PromptHook + 'static, { type Output = Result; - type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent + type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent fn into_future(self) -> Self::IntoFuture { - self.send().boxed() + Box::pin(self.send()) } } diff --git a/rig-core/src/agent/prompt_request/streaming.rs b/rig-core/src/agent/prompt_request/streaming.rs index eb331ade3..fd64f6fc1 100644 --- a/rig-core/src/agent/prompt_request/streaming.rs +++ b/rig-core/src/agent/prompt_request/streaming.rs @@ -3,6 +3,7 @@ use crate::{ completion::GetTokenUsage, message::{AssistantContent, Reasoning, ToolResultContent, UserContent}, streaming::{StreamedAssistantContent, StreamingCompletion}, + wasm_compat::{WasmBoxedFuture, WasmCompatSend}, }; use futures::{Stream, StreamExt}; use serde::{Deserialize, Serialize}; @@ -110,7 +111,7 @@ where impl StreamingPromptRequest where M: CompletionModel + 'static, - ::StreamingResponse: Send + GetTokenUsage, + ::StreamingResponse: WasmCompatSend + GetTokenUsage, P: StreamingPromptHook, { /// Create a new PromptRequest with the given prompt and model @@ -395,11 +396,11 @@ where impl IntoFuture for StreamingPromptRequest where M: CompletionModel + 
'static, - ::StreamingResponse: Send, + ::StreamingResponse: WasmCompatSend, P: StreamingPromptHook + 'static, { type Output = StreamingResult; // what `.await` returns - type IntoFuture = Pin + Send>>; + type IntoFuture = WasmBoxedFuture<'static, Self::Output>; fn into_future(self) -> Self::IntoFuture { // Wrap send() in a future, because send() returns a stream immediately diff --git a/rig-core/src/cli_chatbot.rs b/rig-core/src/cli_chatbot.rs index 4ca023e0a..de8be50bc 100644 --- a/rig-core/src/cli_chatbot.rs +++ b/rig-core/src/cli_chatbot.rs @@ -3,6 +3,7 @@ use crate::{ completion::{Chat, CompletionError, CompletionModel, PromptError, Usage}, message::Message, streaming::{StreamedAssistantContent, StreamingPrompt}, + wasm_compat::WasmCompatSend, }; use futures::StreamExt; use std::io::{self, Write}; @@ -60,7 +61,7 @@ where impl CliChat for AgentImpl where - M: CompletionModel + 'static, + M: CompletionModel + WasmCompatSend + 'static, { async fn request( &mut self, diff --git a/rig-core/src/client/completion.rs b/rig-core/src/client/completion.rs index 5a8ca2dd3..97e8ccbf5 100644 --- a/rig-core/src/client/completion.rs +++ b/rig-core/src/client/completion.rs @@ -6,6 +6,7 @@ use crate::completion::{ }; use crate::extractor::ExtractorBuilder; use crate::streaming::StreamingCompletionResponse; +use crate::wasm_compat::WasmCompatSend; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::future::Future; @@ -72,7 +73,7 @@ impl CompletionModel for CompletionModelHandle<'_> { fn completion( &self, request: CompletionRequest, - ) -> impl Future, CompletionError>> + Send + ) -> impl Future, CompletionError>> + WasmCompatSend { self.inner.completion(request) } @@ -82,7 +83,7 @@ impl CompletionModel for CompletionModelHandle<'_> { request: CompletionRequest, ) -> impl Future< Output = Result, CompletionError>, - > + Send { + > + WasmCompatSend { self.inner.stream(request) } } diff --git a/rig-core/src/client/verify.rs b/rig-core/src/client/verify.rs index 
81a565bc3..08d3f799c 100644 --- a/rig-core/src/client/verify.rs +++ b/rig-core/src/client/verify.rs @@ -1,8 +1,8 @@ use crate::{ client::{AsVerify, ProviderClient}, http_client, + wasm_compat::{WasmBoxedFuture, WasmCompatSend}, }; -use futures::future::BoxFuture; use thiserror::Error; #[derive(Debug, Error)] @@ -23,19 +23,19 @@ pub enum VerifyError { /// Clone is required for conversions between client types. pub trait VerifyClient: ProviderClient + Clone { /// Verify the configuration. - fn verify(&self) -> impl Future> + Send; + fn verify(&self) -> impl Future> + WasmCompatSend; } pub trait VerifyClientDyn: ProviderClient { /// Verify the configuration. - fn verify(&self) -> BoxFuture<'_, Result<(), VerifyError>>; + fn verify(&self) -> WasmBoxedFuture<'_, Result<(), VerifyError>>; } impl VerifyClientDyn for T where T: VerifyClient, { - fn verify(&self) -> BoxFuture<'_, Result<(), VerifyError>> { + fn verify(&self) -> WasmBoxedFuture<'_, Result<(), VerifyError>> { Box::pin(self.verify()) } } diff --git a/rig-core/src/completion/request.rs b/rig-core/src/completion/request.rs index 241fe591f..f3654fc83 100644 --- a/rig-core/src/completion/request.rs +++ b/rig-core/src/completion/request.rs @@ -67,6 +67,7 @@ use super::message::{AssistantContent, DocumentMediaType}; use crate::client::completion::CompletionModelHandle; use crate::message::ToolChoice; use crate::streaming::StreamingCompletionResponse; +use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync}; use crate::{OneOrMany, http_client, streaming}; use crate::{ json_utils, @@ -96,9 +97,15 @@ pub enum CompletionError { #[error("UrlError: {0}")] UrlError(#[from] url::ParseError), + #[cfg(not(target_family = "wasm"))] /// Error building the completion request #[error("RequestError: {0}")] - RequestError(#[from] Box), + RequestError(#[from] Box), + + #[cfg(target_family = "wasm")] + /// Error building the completion request + #[error("RequestError: {0}")] + RequestError(#[from] Box), /// Error 
parsing the completion response #[error("ResponseError: {0}")] @@ -172,7 +179,7 @@ pub struct ToolDefinition { // Implementations // ================================================================ /// Trait defining a high-level LLM simple prompt interface (i.e.: prompt in, response out). -pub trait Prompt: Send + Sync { +pub trait Prompt: WasmCompatSend + WasmCompatSync { /// Send a simple prompt to the underlying completion model. /// /// If the completion model's response is a message, then it is returned as a string. @@ -183,12 +190,12 @@ pub trait Prompt: Send + Sync { /// If the tool does not exist, or the tool call fails, then an error is returned. fn prompt( &self, - prompt: impl Into + Send, - ) -> impl std::future::IntoFuture, IntoFuture: Send>; + prompt: impl Into + WasmCompatSend, + ) -> impl std::future::IntoFuture, IntoFuture: WasmCompatSend>; } /// Trait defining a high-level LLM chat interface (i.e.: prompt and chat history in, response out). -pub trait Chat: Send + Sync { +pub trait Chat: WasmCompatSend + WasmCompatSync { /// Send a prompt with optional chat history to the underlying completion model. /// /// If the completion model's response is a message, then it is returned as a string. @@ -199,9 +206,9 @@ pub trait Chat: Send + Sync { /// If the tool does not exist, or the tool call fails, then an error is returned. fn chat( &self, - prompt: impl Into + Send, + prompt: impl Into + WasmCompatSend, chat_history: Vec, - ) -> impl std::future::IntoFuture, IntoFuture: Send>; + ) -> impl std::future::IntoFuture, IntoFuture: WasmCompatSend>; } /// Trait defining a low-level LLM completion interface @@ -219,9 +226,10 @@ pub trait Completion { /// contain the `preamble` provided when creating the agent. 
fn completion( &self, - prompt: impl Into + Send, + prompt: impl Into + WasmCompatSend, chat_history: Vec, - ) -> impl std::future::Future, CompletionError>> + Send; + ) -> impl std::future::Future, CompletionError>> + + WasmCompatSend; } /// General completion response struct that contains the high-level completion choice @@ -315,14 +323,14 @@ impl AddAssign for Usage { /// Trait defining a completion model that can be used to generate completion responses. /// This trait is meant to be implemented by the user to define a custom completion model, /// either from a third party provider (e.g.: OpenAI) or a local model. -pub trait CompletionModel: Clone + Send + Sync { +pub trait CompletionModel: Clone + WasmCompatSend + WasmCompatSync { /// The raw response type returned by the underlying completion model. - type Response: Send + Sync + Serialize + DeserializeOwned; + type Response: WasmCompatSend + WasmCompatSync + Serialize + DeserializeOwned; /// The raw response type returned by the underlying completion model when streaming. type StreamingResponse: Clone + Unpin - + Send - + Sync + + WasmCompatSend + + WasmCompatSync + Serialize + DeserializeOwned + GetTokenUsage; @@ -333,30 +341,30 @@ pub trait CompletionModel: Clone + Send + Sync { request: CompletionRequest, ) -> impl std::future::Future< Output = Result, CompletionError>, - > + Send; + > + WasmCompatSend; fn stream( &self, request: CompletionRequest, ) -> impl std::future::Future< Output = Result, CompletionError>, - > + Send; + > + WasmCompatSend; /// Generates a completion request builder for the given `prompt`. 
fn completion_request(&self, prompt: impl Into) -> CompletionRequestBuilder { CompletionRequestBuilder::new(self.clone(), prompt) } } -pub trait CompletionModelDyn: Send + Sync { +pub trait CompletionModelDyn: WasmCompatSend + WasmCompatSync { fn completion( &self, request: CompletionRequest, - ) -> BoxFuture<'_, Result, CompletionError>>; + ) -> WasmBoxedFuture<'_, Result, CompletionError>>; fn stream( &self, request: CompletionRequest, - ) -> BoxFuture<'_, Result, CompletionError>>; + ) -> WasmBoxedFuture<'_, Result, CompletionError>>; fn completion_request( &self, @@ -372,7 +380,7 @@ where fn completion( &self, request: CompletionRequest, - ) -> BoxFuture<'_, Result, CompletionError>> { + ) -> WasmBoxedFuture<'_, Result, CompletionError>> { Box::pin(async move { self.completion(request) .await @@ -387,7 +395,7 @@ where fn stream( &self, request: CompletionRequest, - ) -> BoxFuture<'_, Result, CompletionError>> { + ) -> WasmBoxedFuture<'_, Result, CompletionError>> { Box::pin(async move { let resp = self.stream(request).await?; let inner = resp.inner; diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs index df8cafc87..e3e76513a 100644 --- a/rig-core/src/embeddings/embedding.rs +++ b/rig-core/src/embeddings/embedding.rs @@ -6,10 +6,15 @@ //! Finally, the module defines the [EmbeddingError] enum, which represents various errors that //! can occur during embedding generation or processing. -use futures::future::BoxFuture; +use crate::{http_client, if_not_wasm, if_wasm, wasm_compat::*}; use serde::{Deserialize, Serialize}; +if_wasm! { + use futures::future::LocalBoxFuture; +} -use crate::http_client; +if_not_wasm! 
{ + use futures::future::BoxFuture; +} #[derive(Debug, thiserror::Error)] pub enum EmbeddingError { @@ -24,10 +29,16 @@ pub enum EmbeddingError { #[error("UrlError: {0}")] UrlError(#[from] url::ParseError), + #[cfg(not(target_family = "wasm"))] /// Error processing the document for embedding #[error("DocumentError: {0}")] DocumentError(Box), + #[cfg(target_family = "wasm")] + /// Error processing the document for embedding + #[error("DocumentError: {0}")] + DocumentError(Box), + /// Error parsing the completion response #[error("ResponseError: {0}")] ResponseError(String), @@ -38,7 +49,7 @@ pub enum EmbeddingError { } /// Trait for embedding models that can generate embeddings for documents. -pub trait EmbeddingModel: Clone + Sync + Send { +pub trait EmbeddingModel: Clone + WasmCompatSend + WasmCompatSync { /// The maximum number of documents that can be embedded in a single request. const MAX_DOCUMENTS: usize; @@ -48,14 +59,14 @@ pub trait EmbeddingModel: Clone + Sync + Send { /// Embed multiple text documents in a single request fn embed_texts( &self, - texts: impl IntoIterator + Send, - ) -> impl std::future::Future, EmbeddingError>> + Send; + texts: impl IntoIterator + WasmCompatSend, + ) -> impl std::future::Future, EmbeddingError>> + WasmCompatSend; /// Embed a single text document. fn embed_text( &self, text: &str, - ) -> impl std::future::Future> + Send { + ) -> impl std::future::Future> + WasmCompatSend { async { Ok(self .embed_texts(vec![text.to_string()]) @@ -66,42 +77,99 @@ pub trait EmbeddingModel: Clone + Sync + Send { } } -pub trait EmbeddingModelDyn: Sync + Send { - fn max_documents(&self) -> usize; - fn ndims(&self) -> usize; - fn embed_text<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result>; - fn embed_texts( - &self, - texts: Vec, - ) -> BoxFuture<'_, Result, EmbeddingError>>; +if_wasm! 
{ + pub trait EmbeddingModelDyn: WasmCompatSend + WasmCompatSync { + fn max_documents(&self) -> usize; + fn ndims(&self) -> usize; + fn embed_text<'a>( + &'a self, + text: &'a str, + ) -> LocalBoxFuture<'a, Result>; + fn embed_texts( + &self, + texts: Vec, + ) -> LocalBoxFuture<'_, Result, EmbeddingError>>; + } + } -impl EmbeddingModelDyn for T -where - T: EmbeddingModel, -{ - fn max_documents(&self) -> usize { - T::MAX_DOCUMENTS - } +if_wasm! { + impl EmbeddingModelDyn for T + where + T: EmbeddingModel + WasmCompatSend + WasmCompatSync, + { + fn max_documents(&self) -> usize { + T::MAX_DOCUMENTS + } + + fn ndims(&self) -> usize { + self.ndims() + } - fn ndims(&self) -> usize { - self.ndims() + fn embed_text<'a>( + &'a self, + text: &'a str, + ) -> LocalBoxFuture<'a, Result> { + Box::pin(self.embed_text(text)) + } + + fn embed_texts( + &self, + texts: Vec, + ) -> LocalBoxFuture<'_, Result, EmbeddingError>> { + Box::pin(self.embed_texts(texts.into_iter().collect::>())) + } } +} - fn embed_text<'a>(&'a self, text: &'a str) -> BoxFuture<'a, Result> { - Box::pin(self.embed_text(text)) +if_not_wasm! { + pub trait EmbeddingModelDyn: WasmCompatSend + WasmCompatSync { + fn max_documents(&self) -> usize; + fn ndims(&self) -> usize; + fn embed_text<'a>( + &'a self, + text: &'a str, + ) -> LocalBoxFuture<'a, Result>; + fn embed_texts( + &self, + texts: Vec, + ) -> BoxFuture<'_, Result, EmbeddingError>>; } - fn embed_texts( - &self, - texts: Vec, - ) -> BoxFuture<'_, Result, EmbeddingError>> { - Box::pin(self.embed_texts(texts.into_iter().collect::>())) + +} + +if_not_wasm! 
{ + impl EmbeddingModelDyn for T + where + T: EmbeddingModel + WasmCompatSend + WasmCompatSync, + { + fn max_documents(&self) -> usize { + T::MAX_DOCUMENTS + } + + fn ndims(&self) -> usize { + self.ndims() + } + + fn embed_text<'a>( + &'a self, + text: &'a str, + ) -> LocalBoxFuture<'a, Result> { + Box::pin(self.embed_text(text)) + } + + fn embed_texts( + &self, + texts: Vec, + ) -> LocalBoxFuture<'_, Result, EmbeddingError>> { + Box::pin(self.embed_texts(texts.into_iter().collect::>())) + } } } /// Trait for embedding models that can generate embeddings for images. -pub trait ImageEmbeddingModel: Clone + Sync + Send { +pub trait ImageEmbeddingModel: Clone + WasmCompatSend + WasmCompatSync { /// The maximum number of images that can be embedded in a single request. const MAX_DOCUMENTS: usize; @@ -111,14 +179,14 @@ pub trait ImageEmbeddingModel: Clone + Sync + Send { /// Embed multiple images in a single request from bytes. fn embed_images( &self, - images: impl IntoIterator> + Send, + images: impl IntoIterator> + WasmCompatSend, ) -> impl std::future::Future, EmbeddingError>> + Send; /// Embed a single image from bytes. 
fn embed_image<'a>( &'a self, bytes: &'a [u8], - ) -> impl std::future::Future> + Send { + ) -> impl std::future::Future> + WasmCompatSend { async move { Ok(self .embed_images(vec![bytes.to_owned()]) diff --git a/rig-core/src/extractor.rs b/rig-core/src/extractor.rs index 7e4e86c77..191035963 100644 --- a/rig-core/src/extractor.rs +++ b/rig-core/src/extractor.rs @@ -39,6 +39,7 @@ use crate::{ completion::{Completion, CompletionError, CompletionModel, ToolDefinition}, message::{AssistantContent, Message, ToolCall, ToolChoice, ToolFunction}, tool::Tool, + wasm_compat::{WasmCompatSend, WasmCompatSync}, }; const SUBMIT_TOOL_NAME: &str = "submit"; @@ -59,7 +60,7 @@ pub enum ExtractionError { pub struct Extractor where M: CompletionModel, - T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, + T: JsonSchema + for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync, { agent: Agent, _t: PhantomData, @@ -69,7 +70,7 @@ where impl Extractor where M: CompletionModel, - T: JsonSchema + for<'a> Deserialize<'a> + Send + Sync, + T: JsonSchema + for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync, { /// Attempts to extract data from the given text with a number of retries. /// @@ -77,7 +78,10 @@ where /// if the model does not call the `submit` tool. /// /// The number of retries is determined by the `retries` field on the Extractor struct. - pub async fn extract(&self, text: impl Into + Send) -> Result { + pub async fn extract( + &self, + text: impl Into + WasmCompatSend, + ) -> Result { let mut last_error = None; let text_message = text.into(); @@ -108,7 +112,7 @@ where /// The number of retries is determined by the `retries` field on the Extractor struct. 
pub async fn extract_with_chat_history( &self, - text: impl Into + Send, + text: impl Into + WasmCompatSend, chat_history: Vec, ) -> Result { let mut last_error = None; @@ -135,7 +139,7 @@ where async fn extract_json( &self, - text: impl Into + Send, + text: impl Into + WasmCompatSend, messages: Vec, ) -> Result { let response = self.agent.completion(text, messages).await?.send().await?; @@ -205,7 +209,7 @@ where pub struct ExtractorBuilder where M: CompletionModel, - T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync + 'static, + T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync + 'static, { agent_builder: AgentBuilder, _t: PhantomData, @@ -215,7 +219,7 @@ where impl ExtractorBuilder where M: CompletionModel, - T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync + 'static, + T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync + 'static, { pub fn new(model: M) -> Self { Self { @@ -283,7 +287,7 @@ where #[derive(Deserialize, Serialize)] struct SubmitTool where - T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, + T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync, { _t: PhantomData, } @@ -294,7 +298,7 @@ struct SubmitError; impl Tool for SubmitTool where - T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync, + T: JsonSchema + for<'a> Deserialize<'a> + Serialize + WasmCompatSend + WasmCompatSync, { const NAME: &'static str = SUBMIT_TOOL_NAME; type Error = SubmitError; diff --git a/rig-core/src/http_client.rs b/rig-core/src/http_client.rs index d8a9c5116..6e6e2d744 100644 --- a/rig-core/src/http_client.rs +++ b/rig-core/src/http_client.rs @@ -1,38 +1,48 @@ use bytes::Bytes; -use futures::stream::{BoxStream, StreamExt}; +#[cfg(not(target_family = "wasm"))] +use futures::stream::BoxStream; +use futures::stream::Stream; pub use http::{HeaderMap, HeaderValue, Method, Request, Response, Uri, request::Builder}; 
use reqwest::Body; use std::future::Future; use std::pin::Pin; -#[cfg(not(target_arch = "wasm32"))] -pub trait RigSend: Send {} -#[cfg(target_arch = "wasm32")] -pub trait RigSend {} - -#[cfg(not(target_arch = "wasm32"))] -impl RigSend for T {} -#[cfg(target_arch = "wasm32")] -impl RigSend for T {} +use crate::wasm_compat::*; #[derive(Debug, thiserror::Error)] pub enum Error { #[error("Http error: {0}")] Protocol(#[from] http::Error), + #[cfg(not(target_family = "wasm"))] #[error("Http client error: {0}")] Instance(#[from] Box), + + #[cfg(target_family = "wasm")] + #[error("Http client error: {0}")] + Instance(#[from] Box), } pub type Result = std::result::Result; +#[cfg(not(target_family = "wasm"))] fn instance_error(error: E) -> Error { Error::Instance(error.into()) } -pub type LazyBytes = Pin> + Send + 'static>>; -pub type LazyBody = Pin> + Send + 'static>>; +#[cfg(target_family = "wasm")] +fn instance_error(error: E) -> Error { + Error::Instance(error.into()) +} + +pub type LazyBytes = WasmBoxedFuture<'static, Result>; +pub type LazyBody = WasmBoxedFuture<'static, Result>; +#[cfg(not(target_family = "wasm"))] pub type ByteStream = BoxStream<'static, Result>; + +#[cfg(target_family = "wasm")] +pub type ByteStream = Pin> + 'static>>; + pub type StreamingResponse = Response; pub struct NoBody; @@ -61,65 +71,33 @@ pub fn with_bearer_auth(req: Builder, auth: &str) -> Result { Ok(req.header("Authorization", auth_header)) } -pub trait HttpClientExt: Send + Sync { +pub trait HttpClientExt: WasmCompatSend + WasmCompatSync { fn send( &self, req: Request, - ) -> impl Future>>> + Send + ) -> impl Future>>> + WasmCompatSend + 'static where T: Into, - T: RigSend, + T: WasmCompatSend, U: From, - U: RigSend + 'static; + U: WasmCompatSend + 'static; fn send_streaming( &self, req: Request, - ) -> impl Future> + Send + ) -> impl Future> + WasmCompatSend + 'static where T: Into; - - fn get(&self, uri: Uri) -> impl Future>>> + Send - where - T: From + Send + 'static, - { - async { - 
let req = Request::builder() - .method(Method::GET) - .uri(uri) - .body(NoBody)?; - - self.send(req).await - } - } - - fn post( - &self, - uri: Uri, - body: T, - ) -> impl Future>>> + Send - where - T: Into + Send, - R: From + Send + 'static, - { - async { - let req = Request::builder() - .method(Method::POST) - .uri(uri) - .body(body)?; - self.send(req).await - } - } } impl HttpClientExt for reqwest::Client { fn send( &self, req: Request, - ) -> impl Future>>> + Send + ) -> impl Future>>> + WasmCompatSend + 'static where T: Into, - U: From + Send, + U: From + WasmCompatSend, { let (parts, body) = req.into_parts(); let req = self @@ -136,8 +114,12 @@ impl HttpClientExt for reqwest::Client { *hs = response.headers().clone(); } - let body: LazyBody = Box::pin(async move { - let bytes = response.bytes().await.map_err(instance_error)?; + let body: LazyBody = Box::pin(async { + let bytes = response + .bytes() + .await + .map_err(|e| Error::Instance(e.into()))?; + let body = U::from(bytes); Ok(body) }); @@ -149,7 +131,7 @@ impl HttpClientExt for reqwest::Client { fn send_streaming( &self, req: Request, - ) -> impl Future> + Send + ) -> impl Future> + WasmCompatSend + 'static where T: Into, { @@ -162,17 +144,21 @@ impl HttpClientExt for reqwest::Client { async move { let response: reqwest::Response = req.send().await.map_err(instance_error)?; + #[cfg(not(target_family = "wasm"))] let mut res = Response::builder() .status(response.status()) .version(response.version()); + #[cfg(target_family = "wasm")] + let mut res = Response::builder().status(response.status()); + if let Some(hs) = res.headers_mut() { *hs = response.headers().clone(); } let stream: ByteStream = { use futures::TryStreamExt; - Box::pin(response.bytes_stream().map_err(instance_error).boxed()) + Box::pin(response.bytes_stream().map_err(instance_error)) }; Ok(res.body(stream)?) 
diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 47874a8a7..23e7301cd 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -130,6 +130,7 @@ pub mod tool; pub mod tools; pub mod transcription; pub mod vector_store; +pub mod wasm_compat; // Re-export commonly used types and traits pub use completion::message; diff --git a/rig-core/src/pipeline/agent_ops.rs b/rig-core/src/pipeline/agent_ops.rs index a00c603d0..c0773e18c 100644 --- a/rig-core/src/pipeline/agent_ops.rs +++ b/rig-core/src/pipeline/agent_ops.rs @@ -5,6 +5,7 @@ use crate::{ extractor::{ExtractionError, Extractor}, message::Message, vector_store::{self, request::VectorSearchRequest}, + wasm_compat::{WasmCompatSend, WasmCompatSync}, }; use super::Op; @@ -33,8 +34,8 @@ where impl Op for Lookup where I: vector_store::VectorStoreIndex, - In: Into + Send + Sync, - T: Send + Sync + for<'a> serde::Deserialize<'a>, + In: Into + WasmCompatSend + WasmCompatSync, + T: WasmCompatSend + WasmCompatSync + for<'a> serde::Deserialize<'a>, { type Input = In; type Output = Result, vector_store::VectorStoreError>; @@ -60,8 +61,8 @@ where pub fn lookup(index: I, n: usize) -> Lookup where I: vector_store::VectorStoreIndex, - In: Into + Send + Sync, - T: Send + Sync + for<'a> serde::Deserialize<'a>, + In: Into + WasmCompatSend + WasmCompatSync, + T: WasmCompatSend + WasmCompatSync + for<'a> serde::Deserialize<'a>, { Lookup::new(index, n) } @@ -82,13 +83,16 @@ impl Prompt { impl Op for Prompt where - P: completion::Prompt + Send + Sync, - In: Into + Send + Sync, + P: completion::Prompt + WasmCompatSend + WasmCompatSync, + In: Into + WasmCompatSend + WasmCompatSync, { type Input = In; type Output = Result; - fn call(&self, input: Self::Input) -> impl std::future::Future + Send { + fn call( + &self, + input: Self::Input, + ) -> impl std::future::Future + WasmCompatSend { self.prompt.prompt(input.into()).into_future() } } @@ -99,7 +103,7 @@ where pub fn prompt(model: P) -> Prompt where P: completion::Prompt, - 
In: Into + Send + Sync, + In: Into + WasmCompatSend + WasmCompatSync, { Prompt::new(model) } @@ -107,7 +111,7 @@ where pub struct Extract where M: CompletionModel, - Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, + Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync, { extractor: Extractor, _in: std::marker::PhantomData, @@ -116,7 +120,7 @@ where impl Extract where M: CompletionModel, - Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, + Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync, { pub(crate) fn new(extractor: Extractor) -> Self { Self { @@ -129,8 +133,8 @@ where impl Op for Extract where M: CompletionModel, - Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, - Input: Into + Send + Sync, + Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync, + Input: Into + WasmCompatSend + WasmCompatSync, { type Input = Input; type Output = Result; @@ -146,8 +150,8 @@ where pub fn extract(extractor: Extractor) -> Extract where M: CompletionModel, - Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, - Input: Into + Send + Sync, + Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + WasmCompatSend + WasmCompatSync, + Input: Into + WasmCompatSend + WasmCompatSync, { Extract::new(extractor) } @@ -179,7 +183,7 @@ pub mod tests { pub struct MockIndex; impl VectorStoreIndex for MockIndex { - async fn top_n serde::Deserialize<'a> + std::marker::Send>( + async fn top_n serde::Deserialize<'a> + WasmCompatSend>( &self, _req: VectorSearchRequest, ) -> Result, VectorStoreError> { diff --git a/rig-core/src/pipeline/op.rs b/rig-core/src/pipeline/op.rs index fda072dc9..33916f7e2 100644 --- a/rig-core/src/pipeline/op.rs +++ b/rig-core/src/pipeline/op.rs @@ -1,24 +1,28 @@ -use std::future::Future; - +use crate::wasm_compat::*; 
#[allow(unused_imports)] // Needed since this is used in a macro rule use futures::join; use futures::stream; +use std::future::Future; // ================================================================ // Core Op trait // ================================================================ -pub trait Op: Send + Sync { - type Input: Send + Sync; - type Output: Send + Sync; +pub trait Op: WasmCompatSend + WasmCompatSync { + type Input: WasmCompatSend + WasmCompatSync; + type Output: WasmCompatSend + WasmCompatSync; - fn call(&self, input: Self::Input) -> impl Future + Send; + fn call(&self, input: Self::Input) -> impl Future + WasmCompatSend; /// Execute the current pipeline with the given inputs. `n` is the number of concurrent /// inputs that will be processed concurrently. - fn batch_call(&self, n: usize, input: I) -> impl Future> + Send + fn batch_call( + &self, + n: usize, + input: I, + ) -> impl Future> + WasmCompatSend where - I: IntoIterator + Send, - I::IntoIter: Send, + I: IntoIterator + WasmCompatSend, + I::IntoIter: WasmCompatSend, Self: Sized, { use futures::stream::StreamExt; @@ -47,8 +51,8 @@ pub trait Op: Send + Sync { /// ``` fn map(self, f: F) -> Sequential> where - F: Fn(Self::Output) -> Input + Send + Sync, - Input: Send + Sync, + F: Fn(Self::Output) -> Input + WasmCompatSend + WasmCompatSync, + Input: WasmCompatSend + WasmCompatSync, Self: Sized, { Sequential::new(self, Map::new(f)) @@ -73,9 +77,9 @@ pub trait Op: Send + Sync { /// ``` fn then(self, f: F) -> Sequential> where - F: Fn(Self::Output) -> Fut + Send + Sync, - Fut: Future + Send + Sync, - Fut::Output: Send + Sync, + F: Fn(Self::Output) -> Fut + Send + WasmCompatSync, + Fut: Future + WasmCompatSend + WasmCompatSync, + Fut::Output: WasmCompatSend + WasmCompatSync, Self: Sized, { Sequential::new(self, Then::new(f)) @@ -135,7 +139,7 @@ pub trait Op: Send + Sync { ) -> Sequential> where I: vector_store::VectorStoreIndex, - Input: Send + Sync + for<'a> serde::Deserialize<'a>, + Input: 
WasmCompatSend + WasmCompatSync + for<'a> serde::Deserialize<'a>, Self::Output: Into, Self: Sized, { @@ -208,9 +212,8 @@ where } } -use crate::{completion, vector_store}; - use super::agent_ops::{Lookup, Prompt}; +use crate::{completion, vector_store}; // ================================================================ // Core Op implementations @@ -231,9 +234,9 @@ impl Map { impl Op for Map where - F: Fn(Input) -> Output + Send + Sync, - Input: Send + Sync, - Output: Send + Sync, + F: Fn(Input) -> Output + WasmCompatSend + WasmCompatSync, + Input: WasmCompatSend + WasmCompatSync, + Output: WasmCompatSend + WasmCompatSync, { type Input = Input; type Output = Output; @@ -246,9 +249,9 @@ where pub fn map(f: F) -> Map where - F: Fn(Input) -> Output + Send + Sync, - Input: Send + Sync, - Output: Send + Sync, + F: Fn(Input) -> Output + WasmCompatSend + WasmCompatSync, + Input: WasmCompatSend + WasmCompatSync, + Output: WasmCompatSend + WasmCompatSync, { Map::new(f) } @@ -267,7 +270,7 @@ impl Passthrough { impl Op for Passthrough where - T: Send + Sync, + T: WasmCompatSend + WasmCompatSync, { type Input = T; type Output = T; @@ -279,7 +282,7 @@ where pub fn passthrough() -> Passthrough where - T: Send + Sync, + T: WasmCompatSend + WasmCompatSync, { Passthrough::new() } @@ -300,10 +303,10 @@ impl Then { impl Op for Then where - F: Fn(Input) -> Fut + Send + Sync, - Input: Send + Sync, - Fut: Future + Send, - Fut::Output: Send + Sync, + F: Fn(Input) -> Fut + WasmCompatSend + WasmCompatSync, + Input: WasmCompatSend + WasmCompatSync, + Fut: Future + WasmCompatSend, + Fut::Output: WasmCompatSend + WasmCompatSync, { type Input = Input; type Output = Fut::Output; @@ -316,10 +319,10 @@ where pub fn then(f: F) -> Then where - F: Fn(Input) -> Fut + Send + Sync, - Input: Send + Sync, - Fut: Future + Send, - Fut::Output: Send + Sync, + F: Fn(Input) -> Fut + WasmCompatSend + WasmCompatSync, + Input: WasmCompatSend + WasmCompatSync, + Fut: Future + WasmCompatSend, + Fut::Output: 
WasmCompatSend + WasmCompatSync, { Then::new(f) } diff --git a/rig-core/src/pipeline/try_op.rs b/rig-core/src/pipeline/try_op.rs index 8a629b98a..523f6c40d 100644 --- a/rig-core/src/pipeline/try_op.rs +++ b/rig-core/src/pipeline/try_op.rs @@ -4,21 +4,23 @@ use futures::stream; #[allow(unused_imports)] // Needed since this is used in a macro rule use futures::try_join; +use crate::wasm_compat::{WasmCompatSend, WasmCompatSync}; + use super::op::{self}; // ================================================================ // Core TryOp trait // ================================================================ -pub trait TryOp: Send + Sync { - type Input: Send + Sync; - type Output: Send + Sync; - type Error: Send + Sync; +pub trait TryOp: WasmCompatSend + WasmCompatSync { + type Input: WasmCompatSend + WasmCompatSync; + type Output: WasmCompatSend + WasmCompatSync; + type Error: WasmCompatSend + WasmCompatSync; /// Execute the current op with the given input. fn try_call( &self, input: Self::Input, - ) -> impl Future> + Send; + ) -> impl Future> + WasmCompatSend; /// Execute the current op with the given inputs. `n` is the number of concurrent /// inputs that will be processed concurrently. 
@@ -40,10 +42,10 @@ pub trait TryOp: Send + Sync { &self, n: usize, input: I, - ) -> impl Future, Self::Error>> + Send + ) -> impl Future, Self::Error>> + WasmCompatSend where - I: IntoIterator + Send, - I::IntoIter: Send, + I: IntoIterator + WasmCompatSend, + I::IntoIter: WasmCompatSend, Self: Sized, { use stream::{StreamExt, TryStreamExt}; @@ -73,8 +75,8 @@ pub trait TryOp: Send + Sync { /// ``` fn map_ok(self, f: F) -> MapOk> where - F: Fn(Self::Output) -> Output + Send + Sync, - Output: Send + Sync, + F: Fn(Self::Output) -> Output + WasmCompatSend + WasmCompatSync, + Output: WasmCompatSend + WasmCompatSync, Self: Sized, { MapOk::new(self, op::Map::new(f)) @@ -96,8 +98,8 @@ pub trait TryOp: Send + Sync { /// ``` fn map_err(self, f: F) -> MapErr> where - F: Fn(Self::Error) -> E + Send + Sync, - E: Send + Sync, + F: Fn(Self::Error) -> E + WasmCompatSend + WasmCompatSync, + E: WasmCompatSend + WasmCompatSync, Self: Sized, { MapErr::new(self, op::Map::new(f)) @@ -120,9 +122,9 @@ pub trait TryOp: Send + Sync { /// ``` fn and_then(self, f: F) -> AndThen> where - F: Fn(Self::Output) -> Fut + Send + Sync, - Fut: Future> + Send + Sync, - Output: Send + Sync, + F: Fn(Self::Output) -> Fut + WasmCompatSend + WasmCompatSync, + Fut: Future> + WasmCompatSend + WasmCompatSync, + Output: WasmCompatSend + WasmCompatSync, Self: Sized, { AndThen::new(self, op::Then::new(f)) @@ -145,9 +147,9 @@ pub trait TryOp: Send + Sync { /// ``` fn or_else(self, f: F) -> OrElse> where - F: Fn(Self::Error) -> Fut + Send + Sync, - Fut: Future> + Send + Sync, - E: Send + Sync, + F: Fn(Self::Error) -> Fut + WasmCompatSend + WasmCompatSync, + Fut: Future> + WasmCompatSend + WasmCompatSync, + E: WasmCompatSend + WasmCompatSync, Self: Sized, { OrElse::new(self, op::Then::new(f)) @@ -191,8 +193,8 @@ pub trait TryOp: Send + Sync { impl TryOp for Op where Op: super::Op>, - T: Send + Sync, - E: Send + Sync, + T: WasmCompatSend + WasmCompatSync, + E: WasmCompatSend + WasmCompatSync, { type Input = 
Op::Input; type Output = T; diff --git a/rig-core/src/providers/anthropic/completion.rs b/rig-core/src/providers/anthropic/completion.rs index 905c06342..def9dbb28 100644 --- a/rig-core/src/providers/anthropic/completion.rs +++ b/rig-core/src/providers/anthropic/completion.rs @@ -8,6 +8,7 @@ use crate::{ message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning}, one_or_many::string_or_one_or_many, telemetry::{ProviderResponseExt, SpanCombinator}, + wasm_compat::*, }; use std::{convert::Infallible, str::FromStr}; @@ -645,7 +646,10 @@ impl TryFrom for message::Message { } #[derive(Clone)] -pub struct CompletionModel { +pub struct CompletionModel +where + T: WasmCompatSend, +{ pub(crate) client: Client, pub model: String, pub default_max_tokens: Option, @@ -728,7 +732,7 @@ impl TryFrom for ToolChoice { } impl completion::CompletionModel for CompletionModel where - T: HttpClientExt + Clone + Default, + T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static, { type Response = CompletionResponse; type StreamingResponse = StreamingCompletionResponse; diff --git a/rig-core/src/providers/anthropic/decoders/sse.rs b/rig-core/src/providers/anthropic/decoders/sse.rs index fdbbe5c02..26c2e1a14 100644 --- a/rig-core/src/providers/anthropic/decoders/sse.rs +++ b/rig-core/src/providers/anthropic/decoders/sse.rs @@ -1,7 +1,12 @@ +use crate::{ + if_not_wasm, if_wasm, + wasm_compat::{WasmCompatSend, WasmCompatSync}, +}; + use super::line::{self, LineDecoder}; use bytes::Bytes; use futures::{Stream, StreamExt, stream::BoxStream}; -use std::fmt::Debug; +use std::{fmt::Debug, pin::Pin}; use thiserror::Error; #[derive(Debug, Error)] @@ -182,14 +187,30 @@ fn extract_sse_chunk(buffer: &[u8]) -> Option<(Vec, Vec)> { Some((chunk, remaining)) } -pub fn from_response<'a, E>( - stream: BoxStream<'a, Result>, -) -> impl Stream> -where - E: Into>, -{ - iter_sse_messages(stream.map(|result| match result { - Ok(bytes) => Ok(bytes.to_vec()), - Err(e) => 
Err(std::io::Error::other(e)), - })) +if_wasm! { + pub fn from_response<'a, E>( + stream: Pin> + 'a>>, + ) -> impl Stream> + where + E: std::fmt::Display + WasmCompatSend + WasmCompatSync + 'static + { + iter_sse_messages(stream.map(|result| match result { + Ok(bytes) => Ok(bytes.to_vec()), + Err(e) => Err(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())), + })) + } +} + +if_not_wasm! { + pub fn from_response<'a, E>( + stream: BoxStream<'a, Result>, + ) -> impl Stream> + where + E: Into>, + { + iter_sse_messages(stream.map(|result| match result { + Ok(bytes) => Ok(bytes.to_vec()), + Err(e) => Err(std::io::Error::other(e)), + })) + } } diff --git a/rig-core/src/providers/anthropic/streaming.rs b/rig-core/src/providers/anthropic/streaming.rs index 5668c48ea..209037362 100644 --- a/rig-core/src/providers/anthropic/streaming.rs +++ b/rig-core/src/providers/anthropic/streaming.rs @@ -198,7 +198,7 @@ where .body(body) .map_err(http_client::Error::Protocol)?; - let response: http_client::StreamingResponse = self.client.send_streaming(req).await?; + let response = self.client.send_streaming(req).await?; if !response.status().is_success() { let mut stream = response.into_body(); diff --git a/rig-core/src/providers/cohere/client.rs b/rig-core/src/providers/cohere/client.rs index 91a1bd64a..dd7b8e2c7 100644 --- a/rig-core/src/providers/cohere/client.rs +++ b/rig-core/src/providers/cohere/client.rs @@ -3,6 +3,7 @@ use crate::{ client::{VerifyClient, VerifyError}, embeddings::EmbeddingsBuilder, http_client::{self, HttpClientExt}, + wasm_compat::*, }; use super::{CompletionModel, EmbeddingModel}; @@ -107,7 +108,7 @@ impl Client { impl Client where - T: HttpClientExt + Clone, + T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static, { fn req( &self, diff --git a/rig-core/src/providers/cohere/embeddings.rs b/rig-core/src/providers/cohere/embeddings.rs index b2a08b361..ef9d939d6 100644 --- a/rig-core/src/providers/cohere/embeddings.rs +++ 
b/rig-core/src/providers/cohere/embeddings.rs @@ -3,6 +3,7 @@ use super::{Client, client::ApiResponse}; use crate::{ embeddings::{self, EmbeddingError}, http_client::HttpClientExt, + wasm_compat::*, }; use serde::Deserialize; @@ -68,7 +69,7 @@ pub struct EmbeddingModel { impl embeddings::EmbeddingModel for EmbeddingModel where - T: HttpClientExt + Clone, + T: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static, { const MAX_DOCUMENTS: usize = 96; diff --git a/rig-core/src/providers/gemini/client.rs b/rig-core/src/providers/gemini/client.rs index e61cd4f01..26782a192 100644 --- a/rig-core/src/providers/gemini/client.rs +++ b/rig-core/src/providers/gemini/client.rs @@ -6,6 +6,7 @@ use crate::client::{ VerifyClient, VerifyError, impl_conversion_traits, }; use crate::http_client::{self, HttpClientExt}; +use crate::wasm_compat::*; use crate::{ Embed, embeddings::{self}, @@ -294,7 +295,7 @@ where impl VerifyClient for Client where - T: HttpClientExt + Clone + Debug + Default + 'static, + T: HttpClientExt + Clone + Debug + Default + WasmCompatSend + WasmCompatSync + 'static, Client: CompletionClient, { #[cfg_attr(feature = "worker", worker::send)] diff --git a/rig-core/src/providers/gemini/embedding.rs b/rig-core/src/providers/gemini/embedding.rs index 1614fa6c4..096571c39 100644 --- a/rig-core/src/providers/gemini/embedding.rs +++ b/rig-core/src/providers/gemini/embedding.rs @@ -8,6 +8,7 @@ use serde_json::json; use crate::{ embeddings::{self, EmbeddingError}, http_client::HttpClientExt, + wasm_compat::{WasmCompatSend, WasmCompatSync}, }; use super::{Client, client::ApiResponse}; @@ -35,7 +36,7 @@ impl EmbeddingModel { impl embeddings::EmbeddingModel for EmbeddingModel where - T: Send + Sync + Clone + HttpClientExt, + T: Clone + HttpClientExt, { const MAX_DOCUMENTS: usize = 1024; @@ -50,7 +51,7 @@ where #[cfg_attr(feature = "worker", worker::send)] async fn embed_texts( &self, - documents: impl IntoIterator + Send, + documents: impl IntoIterator + 
WasmCompatSend, ) -> Result, EmbeddingError> { let documents: Vec = documents.into_iter().collect(); diff --git a/rig-core/src/providers/gemini/transcription.rs b/rig-core/src/providers/gemini/transcription.rs index 23a79ae08..50c36eb25 100644 --- a/rig-core/src/providers/gemini/transcription.rs +++ b/rig-core/src/providers/gemini/transcription.rs @@ -10,6 +10,7 @@ use crate::{ Blob, Content, GenerateContentRequest, GenerationConfig, Part, PartKind, Role, }, transcription::{self, TranscriptionError}, + wasm_compat::{WasmCompatSend, WasmCompatSync}, }; use super::{Client, completion::gemini_api_types::GenerateContentResponse}; @@ -39,7 +40,7 @@ impl TranscriptionModel { impl transcription::TranscriptionModel for TranscriptionModel where - T: HttpClientExt + Send + Sync + Clone, + T: HttpClientExt + WasmCompatSend + WasmCompatSync + Clone, { type Response = GenerateContentResponse; diff --git a/rig-core/src/providers/huggingface/transcription.rs b/rig-core/src/providers/huggingface/transcription.rs index b2618280d..ac901f635 100644 --- a/rig-core/src/providers/huggingface/transcription.rs +++ b/rig-core/src/providers/huggingface/transcription.rs @@ -3,6 +3,7 @@ use crate::providers::huggingface::Client; use crate::providers::huggingface::completion::ApiResponse; use crate::transcription; use crate::transcription::TranscriptionError; +use crate::wasm_compat::WasmCompatSync; use base64::Engine; use base64::prelude::BASE64_STANDARD; use serde::Deserialize; @@ -47,7 +48,7 @@ impl TranscriptionModel { } impl transcription::TranscriptionModel for TranscriptionModel where - T: HttpClientExt + Clone, + T: HttpClientExt + Clone + WasmCompatSync, { type Response = TranscriptionResponse; diff --git a/rig-core/src/providers/openai/client.rs b/rig-core/src/providers/openai/client.rs index fe1599ab4..7997fa371 100644 --- a/rig-core/src/providers/openai/client.rs +++ b/rig-core/src/providers/openai/client.rs @@ -125,7 +125,6 @@ where { pub(crate) fn post(&self, path: &str) -> 
http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - dbg!(&url); http_client::with_bearer_auth(http_client::Request::post(url), &self.api_key) @@ -133,7 +132,6 @@ where pub(crate) fn get(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path).replace("//", "/"); - dbg!(&url); http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key) diff --git a/rig-core/src/streaming.rs b/rig-core/src/streaming.rs index 8fa0a17c0..bec31d024 100644 --- a/rig-core/src/streaming.rs +++ b/rig-core/src/streaming.rs @@ -16,6 +16,9 @@ use crate::completion::{ Message, Usage, }; use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction}; +#[cfg(not(target_arch = "wasm32"))] +use crate::wasm_compat::WasmBoxedFuture; +use crate::wasm_compat::{WasmCompatSend, WasmCompatSync}; use futures::stream::{AbortHandle, Abortable}; use futures::{Stream, StreamExt}; use serde::{Deserialize, Serialize}; @@ -88,8 +91,7 @@ where } #[cfg(not(target_arch = "wasm32"))] -pub type StreamingResult = - Pin, CompletionError>> + Send>>; +pub type StreamingResult = WasmBoxedFuture, CompletionError>>; #[cfg(target_arch = "wasm32")] pub type StreamingResult = @@ -277,24 +279,27 @@ where pub trait StreamingPrompt where M: CompletionModel + 'static, - ::StreamingResponse: Send, + ::StreamingResponse: WasmCompatSend, R: Clone + Unpin + GetTokenUsage, { /// Stream a simple prompt to the model - fn stream_prompt(&self, prompt: impl Into + Send) -> StreamingPromptRequest; + fn stream_prompt( + &self, + prompt: impl Into + WasmCompatSend, + ) -> StreamingPromptRequest; } /// Trait for high-level streaming chat interface -pub trait StreamingChat: Send + Sync +pub trait StreamingChat: WasmCompatSend + WasmCompatSync where M: CompletionModel + 'static, - ::StreamingResponse: Send, + ::StreamingResponse: WasmCompatSend, R: Clone + Unpin + GetTokenUsage, { /// Stream a chat with history to the model fn stream_chat( 
&self, - prompt: impl Into + Send, + prompt: impl Into + WasmCompatSend, chat_history: Vec, ) -> StreamingPromptRequest; } @@ -304,7 +309,7 @@ pub trait StreamingCompletion { /// Generate a streaming completion from a request fn stream_completion( &self, - prompt: impl Into + Send, + prompt: impl Into + WasmCompatSend, chat_history: Vec, ) -> impl Future, CompletionError>>; } @@ -351,7 +356,7 @@ impl Stream for StreamingResultDyn { /// helper function to stream a completion request to stdout pub async fn stream_to_stdout( - agent: &Agent, + agent: &'static Agent, stream: &mut StreamingCompletionResponse, ) -> Result<(), std::io::Error> where diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index 1d1833085..2c537722f 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -17,14 +17,22 @@ use serde::{Deserialize, Serialize}; use crate::{ completion::{self, ToolDefinition}, embeddings::{embed::EmbedError, tool::ToolSchema}, + if_not_wasm, if_wasm, + wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync}, }; #[derive(Debug, thiserror::Error)] pub enum ToolError { + #[cfg(not(target_family = "wasm"))] /// Error returned by the tool #[error("ToolCallError: {0}")] ToolCallError(#[from] Box), + #[cfg(target_family = "wasm")] + /// Error returned by the tool + #[error("ToolCallError: {0}")] + ToolCallError(#[from] Box), + #[error("JsonError: {0}")] JsonError(#[from] serde_json::Error), } @@ -84,14 +92,14 @@ pub enum ToolError { /// } /// } /// ``` -pub trait Tool: Sized + Send + Sync { +pub trait Tool: Sized + WasmCompatSend + WasmCompatSync { /// The name of the tool. This name should be unique. const NAME: &'static str; /// The error type of the tool. - type Error: std::error::Error + Send + Sync + 'static; + type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static; /// The arguments type of the tool. 
- type Args: for<'a> Deserialize<'a> + Send + Sync; + type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync; /// The output type of the tool. type Output: Serialize; @@ -102,7 +110,10 @@ pub trait Tool: Sized + Send + Sync { /// A method returning the tool definition. The user prompt can be used to /// tailor the definition to the specific use case. - fn definition(&self, _prompt: String) -> impl Future + Send + Sync; + fn definition( + &self, + _prompt: String, + ) -> impl Future + WasmCompatSend + WasmCompatSync; /// The tool execution method. /// Both the arguments and return value are a String since these values are meant to @@ -110,12 +121,12 @@ pub trait Tool: Sized + Send + Sync { fn call( &self, args: Self::Args, - ) -> impl Future> + Send; + ) -> impl Future> + WasmCompatSend; } /// Trait that represents an LLM tool that can be stored in a vector store and RAGged pub trait ToolEmbedding: Tool { - type InitError: std::error::Error + Send + Sync + 'static; + type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static; /// Type of the tool' context. This context will be saved and loaded from the /// vector store when ragging the tool. @@ -126,7 +137,7 @@ pub trait ToolEmbedding: Tool { /// Type of the tool's state. This state will be passed to the tool when initializing it. /// This state can be used to pass runtime arguments to the tool such as clients, /// API keys and other configuration. - type State: Send; + type State: WasmCompatSend; /// A method returning the documents that will be used as embeddings for the tool. /// This allows for a tool to be retrieved from multiple embedding "directions". 
@@ -141,18 +152,12 @@ pub trait ToolEmbedding: Tool { } /// Wrapper trait to allow for dynamic dispatch of simple tools -pub trait ToolDyn: Send + Sync { +pub trait ToolDyn: WasmCompatSend + WasmCompatSync { fn name(&self) -> String; - fn definition( - &self, - prompt: String, - ) -> Pin + Send + Sync + '_>>; + fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>; - fn call( - &self, - args: String, - ) -> Pin> + Send + '_>>; + fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result>; } impl ToolDyn for T { @@ -160,17 +165,11 @@ impl ToolDyn for T { self.name() } - fn definition( - &self, - prompt: String, - ) -> Pin + Send + Sync + '_>> { + fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> { Box::pin(::definition(self, prompt)) } - fn call( - &self, - args: String, - ) -> Pin> + Send + '_>> { + fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result> { Box::pin(async move { match serde_json::from_str(&args) { Ok(args) => ::call(self, args) @@ -244,10 +243,7 @@ pub mod rmcp { self.definition.name.to_string() } - fn definition( - &self, - _prompt: String, - ) -> Pin + Send + Sync + '_>> { + fn definition(&self, _prompt: String) -> WasmBoxedFuture { Box::pin(async move { ToolDefinition { name: self.definition.name.to_string(), @@ -263,10 +259,7 @@ pub mod rmcp { }) } - fn call( - &self, - args: String, - ) -> Pin> + Send + '_>> { + fn call(&self, args: String) -> WasmBoxedFuture> { let name = self.definition.name.clone(); let arguments = serde_json::from_str(&args).unwrap_or_default(); @@ -349,7 +342,7 @@ pub trait ToolEmbeddingDyn: ToolDyn { impl ToolEmbeddingDyn for T where - T: ToolEmbedding, + T: ToolEmbedding + 'static, { fn context(&self) -> serde_json::Result { serde_json::to_value(self.context()) @@ -457,7 +450,7 @@ impl ToolSet { } /// Call a tool with the given name and arguments - pub async fn call(&self, toolname: &str, args: String) -> Result { + pub async fn call<'a>(&'a 
self, toolname: &str, args: String) -> Result { if let Some(tool) = self.tools.get(toolname) { tracing::info!(target: "rig", "Calling tool {toolname} with args:\n{}", diff --git a/rig-core/src/transcription.rs b/rig-core/src/transcription.rs index 24932e820..0ffa3637d 100644 --- a/rig-core/src/transcription.rs +++ b/rig-core/src/transcription.rs @@ -3,6 +3,7 @@ //! handling transcription responses, and defining transcription models. use crate::client::transcription::TranscriptionModelHandle; +use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync}; use crate::{http_client, json_utils}; use futures::future::BoxFuture; use std::sync::Arc; @@ -21,10 +22,16 @@ pub enum TranscriptionError { #[error("JsonError: {0}")] JsonError(#[from] serde_json::Error), + #[cfg(not(target_family = "wasm"))] /// Error building the transcription request #[error("RequestError: {0}")] RequestError(#[from] Box), + #[cfg(target_family = "wasm")] + /// Error building the transcription request + #[error("RequestError: {0}")] + RequestError(#[from] Box), + /// Error parsing the transcription response #[error("ResponseError: {0}")] ResponseError(String), @@ -53,7 +60,7 @@ where data: &[u8], ) -> impl std::future::Future< Output = Result, TranscriptionError>, - > + Send; + > + WasmCompatSend; } /// General transcription response struct that contains the transcription text @@ -66,9 +73,9 @@ pub struct TranscriptionResponse { /// Trait defining a transcription model that can be used to generate transcription requests. /// This trait is meant to be implemented by the user to define a custom transcription model, /// either from a third-party provider (e.g: OpenAI) or a local model. -pub trait TranscriptionModel: Clone + Send + Sync { +pub trait TranscriptionModel: Clone + WasmCompatSend + WasmCompatSync { /// The raw response type returned by the underlying model. 
- type Response: Sync + Send; + type Response: WasmCompatSend + WasmCompatSync; /// Generates a completion response for the given transcription model fn transcription( @@ -76,7 +83,7 @@ pub trait TranscriptionModel: Clone + Send + Sync { request: TranscriptionRequest, ) -> impl std::future::Future< Output = Result, TranscriptionError>, - > + Send; + > + WasmCompatSend; /// Generates a transcription request builder for the given `file` fn transcription_request(&self) -> TranscriptionRequestBuilder { @@ -84,11 +91,11 @@ pub trait TranscriptionModel: Clone + Send + Sync { } } -pub trait TranscriptionModelDyn: Send + Sync { +pub trait TranscriptionModelDyn: WasmCompatSend + WasmCompatSync { fn transcription( &self, request: TranscriptionRequest, - ) -> BoxFuture<'_, Result, TranscriptionError>>; + ) -> WasmBoxedFuture<'_, Result, TranscriptionError>>; fn transcription_request(&self) -> TranscriptionRequestBuilder>; } @@ -100,7 +107,7 @@ where fn transcription( &self, request: TranscriptionRequest, - ) -> BoxFuture<'_, Result, TranscriptionError>> { + ) -> WasmBoxedFuture<'_, Result, TranscriptionError>> { Box::pin(async move { let resp = self.transcription(request).await?; diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 900253ec6..17e1035f8 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -9,6 +9,7 @@ use crate::{ completion::ToolDefinition, embeddings::{Embedding, EmbeddingError}, tool::Tool, + wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync}, }; pub mod in_memory_store; @@ -23,9 +24,14 @@ pub enum VectorStoreError { #[error("Json error: {0}")] JsonError(#[from] serde_json::Error), + #[cfg(not(target_family = "wasm"))] #[error("Datastore error: {0}")] DatastoreError(#[from] Box), + #[cfg(target_family = "wasm")] + #[error("Datastore error: {0}")] + DatastoreError(#[from] Box), + #[error("Missing Id: {0}")] MissingIdError(String), @@ -40,47 +46,45 @@ pub enum VectorStoreError { 
} /// Trait for inserting documents into a vector store. -pub trait InsertDocuments: Send + Sync { +pub trait InsertDocuments: WasmCompatSend + WasmCompatSync { /// Insert documents into the vector store. /// - fn insert_documents( + fn insert_documents( &self, documents: Vec<(Doc, OneOrMany)>, - ) -> impl std::future::Future> + Send; + ) -> impl std::future::Future> + WasmCompatSend; } /// Trait for vector store indexes -pub trait VectorStoreIndex: Send + Sync { +pub trait VectorStoreIndex: WasmCompatSend + WasmCompatSync { /// Get the top n documents based on the distance to the given query. /// The result is a list of tuples of the form (score, id, document) - fn top_n Deserialize<'a> + Send>( + fn top_n Deserialize<'a> + WasmCompatSend>( &self, req: VectorSearchRequest, - ) -> impl std::future::Future, VectorStoreError>> + Send; + ) -> impl std::future::Future, VectorStoreError>> + + WasmCompatSend; /// Same as `top_n` but returns the document ids only. fn top_n_ids( &self, req: VectorSearchRequest, - ) -> impl std::future::Future, VectorStoreError>> + Send; + ) -> impl std::future::Future, VectorStoreError>> + WasmCompatSend; } pub type TopNResults = Result, VectorStoreError>; -pub trait VectorStoreIndexDyn: Send + Sync { - fn top_n<'a>(&'a self, req: VectorSearchRequest) -> BoxFuture<'a, TopNResults>; +pub trait VectorStoreIndexDyn: WasmCompatSend + WasmCompatSync { + fn top_n<'a>(&'a self, req: VectorSearchRequest) -> WasmBoxedFuture<'a, TopNResults>; fn top_n_ids<'a>( &'a self, req: VectorSearchRequest, - ) -> BoxFuture<'a, Result, VectorStoreError>>; + ) -> WasmBoxedFuture<'a, Result, VectorStoreError>>; } impl VectorStoreIndexDyn for I { - fn top_n<'a>( - &'a self, - req: VectorSearchRequest, - ) -> BoxFuture<'a, Result, VectorStoreError>> { + fn top_n<'a>(&'a self, req: VectorSearchRequest) -> WasmBoxedFuture<'a, TopNResults> { Box::pin(async move { Ok(self .top_n::(req) @@ -94,7 +98,7 @@ impl VectorStoreIndexDyn for I { fn top_n_ids<'a>( &'a self, req: 
VectorSearchRequest, - ) -> BoxFuture<'a, Result, VectorStoreError>> { + ) -> WasmBoxedFuture<'a, Result, VectorStoreError>> { Box::pin(self.top_n_ids(req)) } } diff --git a/rig-core/src/wasm_compat.rs b/rig-core/src/wasm_compat.rs new file mode 100644 index 000000000..c854b0064 --- /dev/null +++ b/rig-core/src/wasm_compat.rs @@ -0,0 +1,45 @@ +use std::pin::Pin; + +#[cfg(not(target_arch = "wasm32"))] +pub trait WasmCompatSend: Send {} +#[cfg(target_arch = "wasm32")] +pub trait WasmCompatSend {} + +#[cfg(not(target_arch = "wasm32"))] +impl WasmCompatSend for T where T: Send {} +#[cfg(target_arch = "wasm32")] +impl WasmCompatSend for T {} + +#[cfg(not(target_arch = "wasm32"))] +pub trait WasmCompatSync: Sync {} +#[cfg(target_arch = "wasm32")] +pub trait WasmCompatSync {} + +#[cfg(not(target_arch = "wasm32"))] +impl WasmCompatSync for T where T: Sync {} +#[cfg(target_arch = "wasm32")] +impl WasmCompatSync for T {} + +#[cfg(not(target_family = "wasm"))] +pub type WasmBoxedFuture<'a, T> = Pin> + Send + 'a>; + +#[cfg(target_family = "wasm")] +pub type WasmBoxedFuture<'a, T> = Pin + 'a>>; + +#[macro_export] +macro_rules! if_wasm { + ($($tokens:tt)*) => { + #[cfg(target_family = "wasm")] + $($tokens)* + + }; +} + +#[macro_export] +macro_rules! 
if_not_wasm { + ($($tokens:tt)*) => { + #[cfg(not(target_family = "wasm"))] + $($tokens)* + + }; +} From a183d8309336292f061a902d3bdac49825ead223 Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Thu, 9 Oct 2025 15:58:09 -0400 Subject: [PATCH 14/20] Clippy --- rig-core/src/agent/prompt_request/mod.rs | 2 +- rig-core/src/completion/request.rs | 1 - rig-core/src/providers/anthropic/decoders/sse.rs | 2 +- rig-core/src/providers/gemini/embedding.rs | 2 +- rig-core/src/tool.rs | 3 +-- rig-core/src/transcription.rs | 1 - rig-core/src/vector_store/mod.rs | 1 - 7 files changed, 4 insertions(+), 8 deletions(-) diff --git a/rig-core/src/agent/prompt_request/mod.rs b/rig-core/src/agent/prompt_request/mod.rs index e03dc55a5..c3d7f604d 100644 --- a/rig-core/src/agent/prompt_request/mod.rs +++ b/rig-core/src/agent/prompt_request/mod.rs @@ -7,7 +7,7 @@ use std::{ }; use tracing::{Instrument, span::Id}; -use futures::{FutureExt, StreamExt, future::BoxFuture, stream}; +use futures::{StreamExt, stream}; use tracing::info_span; use crate::{ diff --git a/rig-core/src/completion/request.rs b/rig-core/src/completion/request.rs index f3654fc83..6b32bb9d7 100644 --- a/rig-core/src/completion/request.rs +++ b/rig-core/src/completion/request.rs @@ -74,7 +74,6 @@ use crate::{ message::{Message, UserContent}, tool::ToolSetError, }; -use futures::future::BoxFuture; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use std::collections::HashMap; diff --git a/rig-core/src/providers/anthropic/decoders/sse.rs b/rig-core/src/providers/anthropic/decoders/sse.rs index 26c2e1a14..53c862371 100644 --- a/rig-core/src/providers/anthropic/decoders/sse.rs +++ b/rig-core/src/providers/anthropic/decoders/sse.rs @@ -5,7 +5,7 @@ use crate::{ use super::line::{self, LineDecoder}; use bytes::Bytes; -use futures::{Stream, StreamExt, stream::BoxStream}; +use futures::{Stream, StreamExt}; use std::{fmt::Debug, pin::Pin}; use thiserror::Error; diff --git 
a/rig-core/src/providers/gemini/embedding.rs b/rig-core/src/providers/gemini/embedding.rs index 096571c39..47d453a93 100644 --- a/rig-core/src/providers/gemini/embedding.rs +++ b/rig-core/src/providers/gemini/embedding.rs @@ -8,7 +8,7 @@ use serde_json::json; use crate::{ embeddings::{self, EmbeddingError}, http_client::HttpClientExt, - wasm_compat::{WasmCompatSend, WasmCompatSync}, + wasm_compat::WasmCompatSend, }; use super::{Client, client::ApiResponse}; diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index 2c537722f..e12860054 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -9,7 +9,7 @@ //! The [ToolSet] struct is a collection of tools that can be used by an [Agent](crate::agent::Agent) //! and optionally RAGged. -use std::{collections::HashMap, pin::Pin}; +use std::collections::HashMap; use futures::Future; use serde::{Deserialize, Serialize}; @@ -17,7 +17,6 @@ use serde::{Deserialize, Serialize}; use crate::{ completion::{self, ToolDefinition}, embeddings::{embed::EmbedError, tool::ToolSchema}, - if_not_wasm, if_wasm, wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync}, }; diff --git a/rig-core/src/transcription.rs b/rig-core/src/transcription.rs index 0ffa3637d..eb68abdc2 100644 --- a/rig-core/src/transcription.rs +++ b/rig-core/src/transcription.rs @@ -5,7 +5,6 @@ use crate::client::transcription::TranscriptionModelHandle; use crate::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync}; use crate::{http_client, json_utils}; -use futures::future::BoxFuture; use std::sync::Arc; use std::{fs, path::Path}; use thiserror::Error; diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 17e1035f8..0dc8bdb09 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -1,4 +1,3 @@ -use futures::future::BoxFuture; pub use request::VectorSearchRequest; use reqwest::StatusCode; use serde::{Deserialize, Serialize}; From 5c2162fad4a0b8d4dcd252b5fd22c701fe86f0fe Mon Sep 
17 00:00:00 2001 From: Fay Carsons Date: Thu, 9 Oct 2025 17:00:36 -0400 Subject: [PATCH 15/20] Fix native errors caused by wasm compat --- rig-core/src/completion/request.rs | 8 +- rig-core/src/embeddings/embedding.rs | 121 +++++------------- rig-core/src/http_client.rs | 7 +- .../src/providers/anthropic/decoders/sse.rs | 18 +-- rig-core/src/providers/anthropic/streaming.rs | 2 +- rig-core/src/providers/cohere/streaming.rs | 8 +- rig-core/src/providers/deepseek.rs | 6 +- rig-core/src/providers/gemini/streaming.rs | 8 +- rig-core/src/providers/groq.rs | 6 +- rig-core/src/providers/mistral/completion.rs | 6 +- rig-core/src/providers/ollama.rs | 8 +- .../providers/openai/completion/streaming.rs | 8 +- .../openai/responses_api/streaming.rs | 8 +- .../src/providers/openrouter/streaming.rs | 16 ++- rig-core/src/streaming.rs | 6 +- rig-core/src/tool.rs | 2 +- rig-core/src/wasm_compat.rs | 6 +- 17 files changed, 103 insertions(+), 141 deletions(-) diff --git a/rig-core/src/completion/request.rs b/rig-core/src/completion/request.rs index 6b32bb9d7..d643e1333 100644 --- a/rig-core/src/completion/request.rs +++ b/rig-core/src/completion/request.rs @@ -99,7 +99,7 @@ pub enum CompletionError { #[cfg(not(target_family = "wasm"))] /// Error building the completion request #[error("RequestError: {0}")] - RequestError(#[from] Box), + RequestError(#[from] Box), #[cfg(target_family = "wasm")] /// Error building the completion request @@ -399,11 +399,11 @@ where let resp = self.stream(request).await?; let inner = resp.inner; - let stream = Box::pin(streaming::StreamingResultDyn { + let stream = streaming::StreamingResultDyn { inner: Box::pin(inner), - }); + }; - Ok(StreamingCompletionResponse::stream(stream)) + Ok(StreamingCompletionResponse::stream(Box::pin(stream))) }) } diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs index e3e76513a..2513d2495 100644 --- a/rig-core/src/embeddings/embedding.rs +++ b/rig-core/src/embeddings/embedding.rs @@ 
-6,15 +6,9 @@ //! Finally, the module defines the [EmbeddingError] enum, which represents various errors that //! can occur during embedding generation or processing. -use crate::{http_client, if_not_wasm, if_wasm, wasm_compat::*}; +use crate::wasm_compat::WasmBoxedFuture; +use crate::{http_client, wasm_compat::*}; use serde::{Deserialize, Serialize}; -if_wasm! { - use futures::future::LocalBoxFuture; -} - -if_not_wasm! { - use futures::future::BoxFuture; -} #[derive(Debug, thiserror::Error)] pub enum EmbeddingError { @@ -77,94 +71,43 @@ pub trait EmbeddingModel: Clone + WasmCompatSend + WasmCompatSync { } } -if_wasm! { - pub trait EmbeddingModelDyn: WasmCompatSend + WasmCompatSync { - fn max_documents(&self) -> usize; - fn ndims(&self) -> usize; - fn embed_text<'a>( - &'a self, - text: &'a str, - ) -> LocalBoxFuture<'a, Result>; - fn embed_texts( - &self, - texts: Vec, - ) -> LocalBoxFuture<'_, Result, EmbeddingError>>; - } - +pub trait EmbeddingModelDyn: WasmCompatSend + WasmCompatSync { + fn max_documents(&self) -> usize; + fn ndims(&self) -> usize; + fn embed_text<'a>( + &'a self, + text: &'a str, + ) -> WasmBoxedFuture<'a, Result>; + fn embed_texts( + &self, + texts: Vec, + ) -> WasmBoxedFuture<'_, Result, EmbeddingError>>; } -if_wasm! { - impl EmbeddingModelDyn for T - where - T: EmbeddingModel + WasmCompatSend + WasmCompatSync, - { - fn max_documents(&self) -> usize { - T::MAX_DOCUMENTS - } - - fn ndims(&self) -> usize { - self.ndims() - } - - fn embed_text<'a>( - &'a self, - text: &'a str, - ) -> LocalBoxFuture<'a, Result> { - Box::pin(self.embed_text(text)) - } - - fn embed_texts( - &self, - texts: Vec, - ) -> LocalBoxFuture<'_, Result, EmbeddingError>> { - Box::pin(self.embed_texts(texts.into_iter().collect::>())) - } +impl EmbeddingModelDyn for T +where + T: EmbeddingModel + WasmCompatSend + WasmCompatSync, +{ + fn max_documents(&self) -> usize { + T::MAX_DOCUMENTS } -} -if_not_wasm! 
{ - pub trait EmbeddingModelDyn: WasmCompatSend + WasmCompatSync { - fn max_documents(&self) -> usize; - fn ndims(&self) -> usize; - fn embed_text<'a>( - &'a self, - text: &'a str, - ) -> LocalBoxFuture<'a, Result>; - fn embed_texts( - &self, - texts: Vec, - ) -> BoxFuture<'_, Result, EmbeddingError>>; + fn ndims(&self) -> usize { + self.ndims() } + fn embed_text<'a>( + &'a self, + text: &'a str, + ) -> WasmBoxedFuture<'a, Result> { + Box::pin(self.embed_text(text)) + } -} - -if_not_wasm! { - impl EmbeddingModelDyn for T - where - T: EmbeddingModel + WasmCompatSend + WasmCompatSync, - { - fn max_documents(&self) -> usize { - T::MAX_DOCUMENTS - } - - fn ndims(&self) -> usize { - self.ndims() - } - - fn embed_text<'a>( - &'a self, - text: &'a str, - ) -> LocalBoxFuture<'a, Result> { - Box::pin(self.embed_text(text)) - } - - fn embed_texts( - &self, - texts: Vec, - ) -> LocalBoxFuture<'_, Result, EmbeddingError>> { - Box::pin(self.embed_texts(texts.into_iter().collect::>())) - } + fn embed_texts( + &self, + texts: Vec, + ) -> WasmBoxedFuture<'_, Result, EmbeddingError>> { + Box::pin(self.embed_texts(texts.into_iter().collect::>())) } } diff --git a/rig-core/src/http_client.rs b/rig-core/src/http_client.rs index 6e6e2d744..6d4eda8eb 100644 --- a/rig-core/src/http_client.rs +++ b/rig-core/src/http_client.rs @@ -1,11 +1,16 @@ +use crate::if_wasm; use bytes::Bytes; #[cfg(not(target_family = "wasm"))] use futures::stream::BoxStream; +#[cfg(target_family = "wasm")] use futures::stream::Stream; pub use http::{HeaderMap, HeaderValue, Method, Request, Response, Uri, request::Builder}; use reqwest::Body; use std::future::Future; -use std::pin::Pin; + +if_wasm! 
{ + use std::pin::Pin; +} use crate::wasm_compat::*; diff --git a/rig-core/src/providers/anthropic/decoders/sse.rs b/rig-core/src/providers/anthropic/decoders/sse.rs index 53c862371..6196b2450 100644 --- a/rig-core/src/providers/anthropic/decoders/sse.rs +++ b/rig-core/src/providers/anthropic/decoders/sse.rs @@ -1,13 +1,15 @@ -use crate::{ - if_not_wasm, if_wasm, - wasm_compat::{WasmCompatSend, WasmCompatSync}, -}; - use super::line::{self, LineDecoder}; +use crate::{if_not_wasm, if_wasm}; use bytes::Bytes; use futures::{Stream, StreamExt}; -use std::{fmt::Debug, pin::Pin}; +use std::fmt::Debug; use thiserror::Error; +if_not_wasm! { + use futures::stream::BoxStream; +} +if_wasm! { + use std::pin::Pin; +} #[derive(Debug, Error)] pub enum SSEDecoderError { @@ -192,7 +194,7 @@ if_wasm! { stream: Pin> + 'a>>, ) -> impl Stream> where - E: std::fmt::Display + WasmCompatSend + WasmCompatSync + 'static + E: std::fmt::Display + 'static { iter_sse_messages(stream.map(|result| match result { Ok(bytes) => Ok(bytes.to_vec()), @@ -206,7 +208,7 @@ if_not_wasm! { stream: BoxStream<'a, Result>, ) -> impl Stream> where - E: Into>, + E: Into> { iter_sse_messages(stream.map(|result| match result { Ok(bytes) => Ok(bytes.to_vec()), diff --git a/rig-core/src/providers/anthropic/streaming.rs b/rig-core/src/providers/anthropic/streaming.rs index 209037362..d0f62cf57 100644 --- a/rig-core/src/providers/anthropic/streaming.rs +++ b/rig-core/src/providers/anthropic/streaming.rs @@ -208,7 +208,7 @@ where break; }; - let chunk = chunk?; + let chunk: Vec = chunk?.into(); let str = String::from_utf8_lossy(&chunk); diff --git a/rig-core/src/providers/cohere/streaming.rs b/rig-core/src/providers/cohere/streaming.rs index f62ca4bba..7461a75b6 100644 --- a/rig-core/src/providers/cohere/streaming.rs +++ b/rig-core/src/providers/cohere/streaming.rs @@ -129,7 +129,7 @@ impl CompletionModel { .await .map_err(|e| CompletionError::ProviderError(e.to_string()))?; - let stream = Box::pin(stream! 
{ + let stream = stream! { let mut current_tool_call: Option<(String, String, String)> = None; let mut text_response = String::new(); let mut tool_calls = Vec::new(); @@ -242,8 +242,10 @@ impl CompletionModel { } event_source.close(); - }.instrument(span)); + }.instrument(span); - Ok(streaming::StreamingCompletionResponse::stream(stream)) + Ok(streaming::StreamingCompletionResponse::stream(Box::pin( + stream, + ))) } } diff --git a/rig-core/src/providers/deepseek.rs b/rig-core/src/providers/deepseek.rs index 09040ff81..203bf37f6 100644 --- a/rig-core/src/providers/deepseek.rs +++ b/rig-core/src/providers/deepseek.rs @@ -796,7 +796,7 @@ pub async fn send_compatible_streaming_request( .eventsource() .expect("Cloning request must succeed"); - let stream = Box::pin(stream! { + let stream = stream! { let mut final_usage = Usage::new(); let mut text_response = String::new(); let mut calls: HashMap = HashMap::new(); @@ -930,10 +930,10 @@ pub async fn send_compatible_streaming_request( yield Ok(crate::streaming::RawStreamingChoice::FinalResponse( StreamingCompletionResponse { usage: final_usage.clone() } )); - }); + }; Ok(crate::streaming::StreamingCompletionResponse::stream( - stream, + Box::pin(stream), )) } diff --git a/rig-core/src/providers/gemini/streaming.rs b/rig-core/src/providers/gemini/streaming.rs index d0ba60433..99503360e 100644 --- a/rig-core/src/providers/gemini/streaming.rs +++ b/rig-core/src/providers/gemini/streaming.rs @@ -113,7 +113,7 @@ impl CompletionModel { .eventsource() .expect("Cloning request must always succeed"); - let stream = Box::pin(stream! { + let stream = stream! 
{ let mut text_response = String::new(); let mut model_outputs: Vec = Vec::new(); while let Some(event_result) = event_source.next().await { @@ -202,8 +202,10 @@ impl CompletionModel { // Ensure event source is closed when stream ends event_source.close(); - }); + }; - Ok(streaming::StreamingCompletionResponse::stream(stream)) + Ok(streaming::StreamingCompletionResponse::stream(Box::pin( + stream, + ))) } } diff --git a/rig-core/src/providers/groq.rs b/rig-core/src/providers/groq.rs index 47868814d..4f04aa555 100644 --- a/rig-core/src/providers/groq.rs +++ b/rig-core/src/providers/groq.rs @@ -729,7 +729,7 @@ pub async fn send_compatible_streaming_request( .eventsource() .expect("Cloning request must succeed"); - let stream = Box::pin(stream! { + let stream = stream! { let span = tracing::Span::current(); let mut final_usage = Usage { prompt_tokens: 0, @@ -873,9 +873,9 @@ pub async fn send_compatible_streaming_request( yield Ok(crate::streaming::RawStreamingChoice::FinalResponse( StreamingCompletionResponse { usage: final_usage.clone() } )); - }.instrument(span)); + }.instrument(span); Ok(crate::streaming::StreamingCompletionResponse::stream( - stream, + Box::pin(stream), )) } diff --git a/rig-core/src/providers/mistral/completion.rs b/rig-core/src/providers/mistral/completion.rs index 4b1c5c534..2e6fbf746 100644 --- a/rig-core/src/providers/mistral/completion.rs +++ b/rig-core/src/providers/mistral/completion.rs @@ -559,7 +559,7 @@ where ) -> Result, CompletionError> { let resp = self.completion(request).await?; - let stream = Box::pin(stream! { + let stream = stream! 
{ for c in resp.choice.clone() { match c { message::AssistantContent::Text(t) => { @@ -580,9 +580,9 @@ where } yield Ok(RawStreamingChoice::FinalResponse(resp.raw_response.clone())); - }); + }; - Ok(StreamingCompletionResponse::stream(stream)) + Ok(StreamingCompletionResponse::stream(Box::pin(stream))) } } diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs index 07a378983..8f8675bc0 100644 --- a/rig-core/src/providers/ollama.rs +++ b/rig-core/src/providers/ollama.rs @@ -683,7 +683,7 @@ impl completion::CompletionModel for CompletionModel { )); } - let stream = Box::pin(try_stream! { + let stream = try_stream! { let span = tracing::Span::current(); let mut byte_stream = response.bytes_stream(); let mut tool_calls_final = Vec::new(); @@ -743,9 +743,11 @@ impl completion::CompletionModel for CompletionModel { } } } - }.instrument(span)); + }.instrument(span); - Ok(streaming::StreamingCompletionResponse::stream(stream)) + Ok(streaming::StreamingCompletionResponse::stream(Box::pin( + stream, + ))) } } diff --git a/rig-core/src/providers/openai/completion/streaming.rs b/rig-core/src/providers/openai/completion/streaming.rs index a5f3473ef..4c67d1d62 100644 --- a/rig-core/src/providers/openai/completion/streaming.rs +++ b/rig-core/src/providers/openai/completion/streaming.rs @@ -119,7 +119,7 @@ pub async fn send_compatible_streaming_request( .eventsource() .expect("Cloning request must always succeed"); - let stream = Box::pin(stream! { + let stream = stream! 
{ let span = tracing::Span::current(); let mut final_usage = Usage::new(); @@ -266,7 +266,9 @@ pub async fn send_compatible_streaming_request( yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { usage: final_usage.clone() })); - }.instrument(span)); + }.instrument(span); - Ok(streaming::StreamingCompletionResponse::stream(stream)) + Ok(streaming::StreamingCompletionResponse::stream(Box::pin( + stream, + ))) } diff --git a/rig-core/src/providers/openai/responses_api/streaming.rs b/rig-core/src/providers/openai/responses_api/streaming.rs index d952ac674..379f40329 100644 --- a/rig-core/src/providers/openai/responses_api/streaming.rs +++ b/rig-core/src/providers/openai/responses_api/streaming.rs @@ -230,7 +230,7 @@ impl ResponsesCompletionModel { .eventsource() .expect("Cloning request must always succeed"); - let stream = Box::pin(stream! { + let stream = stream! { let mut final_usage = ResponsesUsage::new(); let mut tool_calls: Vec> = Vec::new(); @@ -331,8 +331,10 @@ impl ResponsesCompletionModel { yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse { usage: final_usage.clone() })); - }.instrument(span)); + }.instrument(span); - Ok(streaming::StreamingCompletionResponse::stream(stream)) + Ok(streaming::StreamingCompletionResponse::stream(Box::pin( + stream, + ))) } } diff --git a/rig-core/src/providers/openrouter/streaming.rs b/rig-core/src/providers/openrouter/streaming.rs index 8c8383027..cabf4e4e1 100644 --- a/rig-core/src/providers/openrouter/streaming.rs +++ b/rig-core/src/providers/openrouter/streaming.rs @@ -169,7 +169,7 @@ pub async fn send_streaming_request( } // Handle OpenAI Compatible SSE chunks - let stream = Box::pin(stream! { + let stream = stream! 
{ let mut stream = response.bytes_stream(); let mut tool_calls = HashMap::new(); let mut partial_line = String::new(); @@ -346,9 +346,11 @@ pub async fn send_streaming_request( usage: final_usage.unwrap_or_default() })) - }); + }; - Ok(streaming::StreamingCompletionResponse::stream(stream)) + Ok(streaming::StreamingCompletionResponse::stream(Box::pin( + stream, + ))) } pub async fn send_streaming_request1( @@ -358,7 +360,7 @@ pub async fn send_streaming_request1( .eventsource() .expect("Cloning request must always succeed"); - let stream = Box::pin(stream! { + let stream = stream! { // Accumulate tool calls by index while streaming let mut tool_calls: HashMap = HashMap::new(); let mut final_usage = None; @@ -509,7 +511,9 @@ pub async fn send_streaming_request1( yield Ok(streaming::RawStreamingChoice::FinalResponse(FinalCompletionResponse { usage: final_usage.unwrap_or_default(), })); - }); + }; - Ok(streaming::StreamingCompletionResponse::stream(stream)) + Ok(streaming::StreamingCompletionResponse::stream(Box::pin( + stream, + ))) } diff --git a/rig-core/src/streaming.rs b/rig-core/src/streaming.rs index bec31d024..d884adb55 100644 --- a/rig-core/src/streaming.rs +++ b/rig-core/src/streaming.rs @@ -16,13 +16,10 @@ use crate::completion::{ Message, Usage, }; use crate::message::{AssistantContent, Reasoning, Text, ToolCall, ToolFunction}; -#[cfg(not(target_arch = "wasm32"))] -use crate::wasm_compat::WasmBoxedFuture; use crate::wasm_compat::{WasmCompatSend, WasmCompatSync}; use futures::stream::{AbortHandle, Abortable}; use futures::{Stream, StreamExt}; use serde::{Deserialize, Serialize}; -use std::boxed::Box; use std::future::Future; use std::pin::Pin; use std::sync::atomic::AtomicBool; @@ -91,7 +88,8 @@ where } #[cfg(not(target_arch = "wasm32"))] -pub type StreamingResult = WasmBoxedFuture, CompletionError>>; +pub type StreamingResult = + Pin, CompletionError>> + Send>>; #[cfg(target_arch = "wasm32")] pub type StreamingResult = diff --git a/rig-core/src/tool.rs 
b/rig-core/src/tool.rs index e12860054..4703f8e4e 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -449,7 +449,7 @@ impl ToolSet { } /// Call a tool with the given name and arguments - pub async fn call<'a>(&'a self, toolname: &str, args: String) -> Result { + pub async fn call(&self, toolname: &str, args: String) -> Result { if let Some(tool) = self.tools.get(toolname) { tracing::info!(target: "rig", "Calling tool {toolname} with args:\n{}", diff --git a/rig-core/src/wasm_compat.rs b/rig-core/src/wasm_compat.rs index c854b0064..f4be78eac 100644 --- a/rig-core/src/wasm_compat.rs +++ b/rig-core/src/wasm_compat.rs @@ -21,7 +21,7 @@ impl WasmCompatSync for T where T: Sync {} impl WasmCompatSync for T {} #[cfg(not(target_family = "wasm"))] -pub type WasmBoxedFuture<'a, T> = Pin> + Send + 'a>; +pub type WasmBoxedFuture<'a, T> = Pin + Send + 'a>>; #[cfg(target_family = "wasm")] pub type WasmBoxedFuture<'a, T> = Pin + 'a>>; @@ -29,7 +29,7 @@ pub type WasmBoxedFuture<'a, T> = Pin + 'a>>; #[macro_export] macro_rules! if_wasm { ($($tokens:tt)*) => { - #[cfg(target_family = "wasm")] + #[cfg(target_arch = "wasm32")] $($tokens)* }; @@ -38,7 +38,7 @@ macro_rules! if_wasm { #[macro_export] macro_rules! 
if_not_wasm { ($($tokens:tt)*) => { - #[cfg(not(target_family = "wasm"))] + #[cfg(not(target_arch = "wasm32"))] $($tokens)* }; From 243e9f52092cbcfe0dd12bf73ebcdfa058036e63 Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Thu, 9 Oct 2025 17:03:20 -0400 Subject: [PATCH 16/20] Fix imports --- rig-core/src/tool.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index 4703f8e4e..bc431644b 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -191,7 +191,7 @@ pub mod rmcp { use crate::tool::ToolError; use rmcp::model::RawContent; use std::borrow::Cow; - use std::pin::Pin; + use wasm_compat::WasmBoxedFuture; pub struct McpTool { definition: rmcp::model::Tool, From d783b92bfb92b0a0abb361b04170083facb0f13d Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Thu, 9 Oct 2025 17:07:44 -0400 Subject: [PATCH 17/20] Fix imports 2: the sequel --- rig-core/src/tool.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index bc431644b..ddd2252a5 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -189,9 +189,9 @@ pub mod rmcp { use crate::completion::ToolDefinition; use crate::tool::ToolDyn; use crate::tool::ToolError; + use crate::wasm_compat::WasmBoxedFuture; use rmcp::model::RawContent; use std::borrow::Cow; - use wasm_compat::WasmBoxedFuture; pub struct McpTool { definition: rmcp::model::Tool, From 7e91c4c7855625c1748d10efea33b85709804c78 Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Thu, 9 Oct 2025 17:12:53 -0400 Subject: [PATCH 18/20] Local clippy OK --- rig-core/src/tool.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index ddd2252a5..45065c8e2 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -242,7 +242,7 @@ pub mod rmcp { self.definition.name.to_string() } - fn definition(&self, _prompt: String) -> WasmBoxedFuture { + fn definition(&self, _prompt: 
String) -> WasmBoxedFuture<'_, ToolDefinition> { Box::pin(async move { ToolDefinition { name: self.definition.name.to_string(), @@ -258,7 +258,7 @@ pub mod rmcp { }) } - fn call(&self, args: String) -> WasmBoxedFuture> { + fn call(&self, args: String) -> WasmBoxedFuture<'_, Result> { let name = self.definition.name.clone(); let arguments = serde_json::from_str(&args).unwrap_or_default(); From 55986abe2bc68a4f6ad61fffcdfa236da1f85274 Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Thu, 9 Oct 2025 17:19:34 -0400 Subject: [PATCH 19/20] Fix openai path handling --- rig-core/src/providers/openai/client.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rig-core/src/providers/openai/client.rs b/rig-core/src/providers/openai/client.rs index 7997fa371..929913049 100644 --- a/rig-core/src/providers/openai/client.rs +++ b/rig-core/src/providers/openai/client.rs @@ -124,14 +124,14 @@ where T: HttpClientExt, { pub(crate) fn post(&self, path: &str) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); dbg!(&url); http_client::with_bearer_auth(http_client::Request::post(url), &self.api_key) } pub(crate) fn get(&self, path: &str) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); dbg!(&url); http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key) From e71a80d26d905976e8e18398baed22c2c464f585 Mon Sep 17 00:00:00 2001 From: Fay Carsons Date: Thu, 9 Oct 2025 17:54:23 -0400 Subject: [PATCH 20/20] Fix client URI handling --- rig-core/src/providers/anthropic/client.rs | 4 ++-- rig-core/src/providers/azure.rs | 7 ++++--- rig-core/src/providers/cohere/client.rs | 2 +- rig-core/src/providers/deepseek.rs | 2 +- rig-core/src/providers/galadriel.rs | 4 ++-- rig-core/src/providers/gemini/client.rs | 22 
+++++++++++++++++---- rig-core/src/providers/groq.rs | 2 +- rig-core/src/providers/hyperbolic.rs | 2 +- rig-core/src/providers/mira.rs | 4 ++-- rig-core/src/providers/mistral/client.rs | 2 +- rig-core/src/providers/moonshot.rs | 4 ++-- rig-core/src/providers/ollama.rs | 4 ++-- rig-core/src/providers/openai/client.rs | 4 +--- rig-core/src/providers/openrouter/client.rs | 4 ++-- rig-core/src/providers/perplexity.rs | 2 +- rig-core/src/providers/together/client.rs | 2 +- rig-core/src/providers/voyageai.rs | 2 +- rig-core/src/providers/xai/client.rs | 4 ++-- rig-eternalai/src/providers/eternalai.rs | 2 +- 19 files changed, 46 insertions(+), 33 deletions(-) diff --git a/rig-core/src/providers/anthropic/client.rs b/rig-core/src/providers/anthropic/client.rs index 4648d43fd..3035c3e7c 100644 --- a/rig-core/src/providers/anthropic/client.rs +++ b/rig-core/src/providers/anthropic/client.rs @@ -172,7 +172,7 @@ where } pub(crate) fn post(&self, path: &str) -> http_client::Builder { - let uri = format!("{}/{}", ANTHROPIC_API_BASE_URL, path).replace("//", "/"); + let uri = format!("{}/{}", self.base_url, path.trim_start_matches('/')); let mut headers = self.default_headers.clone(); @@ -193,7 +193,7 @@ where } pub(crate) fn get(&self, path: &str) -> http_client::Builder { - let uri = format!("{}/{}", self.base_url, path).replace("//", "/"); + let uri = format!("{}/{}", self.base_url, path.trim_start_matches('/')); let mut headers = self.default_headers.clone(); headers.insert( diff --git a/rig-core/src/providers/azure.rs b/rig-core/src/providers/azure.rs index 0ab3af8c6..66cf02fcc 100644 --- a/rig-core/src/providers/azure.rs +++ b/rig-core/src/providers/azure.rs @@ -186,9 +186,10 @@ where fn post_embedding(&self, deployment_id: &str) -> http_client::Builder { let url = format!( "{}/openai/deployments/{}/embeddings?api-version={}", - self.azure_endpoint, deployment_id, self.api_version - ) - .replace("//", "/"); + self.azure_endpoint, + deployment_id.trim_start_matches('/'), + 
self.api_version + ); self.post(url) } diff --git a/rig-core/src/providers/cohere/client.rs b/rig-core/src/providers/cohere/client.rs index dd7b8e2c7..4cd9c3cba 100644 --- a/rig-core/src/providers/cohere/client.rs +++ b/rig-core/src/providers/cohere/client.rs @@ -115,7 +115,7 @@ where method: http_client::Method, path: &str, ) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); http_client::with_bearer_auth( http_client::Builder::new().method(method).uri(url), diff --git a/rig-core/src/providers/deepseek.rs b/rig-core/src/providers/deepseek.rs index 203bf37f6..43389ba15 100644 --- a/rig-core/src/providers/deepseek.rs +++ b/rig-core/src/providers/deepseek.rs @@ -134,7 +134,7 @@ where method: http_client::Method, path: &str, ) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); http_client::with_bearer_auth( http_client::Request::builder().method(method).uri(url), diff --git a/rig-core/src/providers/galadriel.rs b/rig-core/src/providers/galadriel.rs index 5ab3bd402..087168a4f 100644 --- a/rig-core/src/providers/galadriel.rs +++ b/rig-core/src/providers/galadriel.rs @@ -136,7 +136,7 @@ where T: HttpClientExt, { pub(crate) fn post(&self, path: &str) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); let mut req = http_client::Request::post(url); @@ -161,7 +161,7 @@ where impl Client { fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); let mut req = self.http_client.post(url).bearer_auth(&self.api_key); if let Some(fine_tune_key) = self.fine_tune_api_key.clone() { 
diff --git a/rig-core/src/providers/gemini/client.rs b/rig-core/src/providers/gemini/client.rs index 26782a192..c567896c1 100644 --- a/rig-core/src/providers/gemini/client.rs +++ b/rig-core/src/providers/gemini/client.rs @@ -130,8 +130,12 @@ where impl Client { pub(crate) fn post_sse(&self, path: &str) -> reqwest::RequestBuilder { - let url = - format!("{}/{}?alt=sse&key={}", self.base_url, path, self.api_key).replace("//", "/"); + let url = format!( + "{}/{}?alt=sse&key={}", + self.base_url, + path.trim_start_matches('/'), + self.api_key + ); tracing::debug!("POST {}/{}?alt=sse&key={}", self.base_url, path, "****"); @@ -147,7 +151,12 @@ where { pub(crate) fn post(&self, path: &str) -> http_client::Builder { // API key gets inserted as query param - no need to add bearer auth or headers - let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/"); + let url = format!( + "{}/{}?key={}", + self.base_url, + path.trim_start_matches('/'), + self.api_key + ); tracing::debug!("POST {}/{}?key={}", self.base_url, path, "****"); let mut req = http_client::Request::post(url); @@ -161,7 +170,12 @@ where pub(crate) fn get(&self, path: &str) -> http_client::Builder { // API key gets inserted as query param - no need to add bearer auth or headers - let url = format!("{}/{}?key={}", self.base_url, path, self.api_key).replace("//", "/"); + let url = format!( + "{}/{}?key={}", + self.base_url, + path.trim_start_matches('/'), + self.api_key + ); tracing::debug!("GET {}/{}?key={}", self.base_url, path, "****"); diff --git a/rig-core/src/providers/groq.rs b/rig-core/src/providers/groq.rs index 4f04aa555..d97f2e1ca 100644 --- a/rig-core/src/providers/groq.rs +++ b/rig-core/src/providers/groq.rs @@ -140,7 +140,7 @@ where method: http_client::Method, path: &str, ) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); 
http_client::with_bearer_auth( http_client::Builder::new().method(method).uri(url), diff --git a/rig-core/src/providers/hyperbolic.rs b/rig-core/src/providers/hyperbolic.rs index bc27a70dd..e94102f7f 100644 --- a/rig-core/src/providers/hyperbolic.rs +++ b/rig-core/src/providers/hyperbolic.rs @@ -130,7 +130,7 @@ where method: http_client::Method, path: &str, ) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); http_client::with_bearer_auth( http_client::Builder::new().method(method).uri(url), diff --git a/rig-core/src/providers/mira.rs b/rig-core/src/providers/mira.rs index 24897e4ea..2dab1e7c9 100644 --- a/rig-core/src/providers/mira.rs +++ b/rig-core/src/providers/mira.rs @@ -251,7 +251,7 @@ where method: http_client::Method, path: &str, ) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); let mut req = http_client::Builder::new().method(method).uri(url); @@ -269,7 +269,7 @@ where impl Client { fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); self.http_client .post(url) diff --git a/rig-core/src/providers/mistral/client.rs b/rig-core/src/providers/mistral/client.rs index 0e10cee31..a1a65146d 100644 --- a/rig-core/src/providers/mistral/client.rs +++ b/rig-core/src/providers/mistral/client.rs @@ -102,7 +102,7 @@ where T: HttpClientExt, { pub(crate) fn post(&self, path: &str) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); http_client::with_bearer_auth(http_client::Request::post(url), &self.api_key) } diff --git a/rig-core/src/providers/moonshot.rs 
b/rig-core/src/providers/moonshot.rs index ad99ac821..6a729e465 100644 --- a/rig-core/src/providers/moonshot.rs +++ b/rig-core/src/providers/moonshot.rs @@ -126,7 +126,7 @@ where method: http_client::Method, path: &str, ) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); http_client::with_bearer_auth( http_client::Builder::new().method(method).uri(url), @@ -141,7 +141,7 @@ where impl Client { fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); self.http_client.post(url).bearer_auth(&self.api_key) } diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs index 8f8675bc0..475c2fe8d 100644 --- a/rig-core/src/providers/ollama.rs +++ b/rig-core/src/providers/ollama.rs @@ -150,7 +150,7 @@ where impl Client { fn req(&self, method: http_client::Method, path: &str) -> http_client::Builder { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); http_client::Builder::new().method(method).uri(url) } @@ -165,7 +165,7 @@ impl Client { impl Client { fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); self.http_client.post(url) } } diff --git a/rig-core/src/providers/openai/client.rs b/rig-core/src/providers/openai/client.rs index 929913049..463d21811 100644 --- a/rig-core/src/providers/openai/client.rs +++ b/rig-core/src/providers/openai/client.rs @@ -125,14 +125,12 @@ where { pub(crate) fn post(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); - dbg!(&url); 
http_client::with_bearer_auth(http_client::Request::post(url), &self.api_key) } pub(crate) fn get(&self, path: &str) -> http_client::Result { let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); - dbg!(&url); http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key) } @@ -151,7 +149,7 @@ where impl Client { pub(crate) fn post_reqwest(&self, path: &str) -> reqwest::RequestBuilder { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); self.http_client.post(url).bearer_auth(&self.api_key) } diff --git a/rig-core/src/providers/openrouter/client.rs b/rig-core/src/providers/openrouter/client.rs index f44446e0b..7f9da98fa 100644 --- a/rig-core/src/providers/openrouter/client.rs +++ b/rig-core/src/providers/openrouter/client.rs @@ -82,7 +82,7 @@ impl Client { } pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); self.http_client.post(url).bearer_auth(&self.api_key) } @@ -117,7 +117,7 @@ where impl Client { pub(crate) fn get(&self, path: &str) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key) } diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs index b6a3ef015..0992985ff 100644 --- a/rig-core/src/providers/perplexity.rs +++ b/rig-core/src/providers/perplexity.rs @@ -121,7 +121,7 @@ where impl Client { fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); 
self.http_client.post(url).bearer_auth(&self.api_key) } } diff --git a/rig-core/src/providers/together/client.rs b/rig-core/src/providers/together/client.rs index a673f530a..eb403ea87 100644 --- a/rig-core/src/providers/together/client.rs +++ b/rig-core/src/providers/together/client.rs @@ -113,7 +113,7 @@ where T: HttpClientExt, { pub(crate) fn post(&self, path: &str) -> http_client::Result { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); tracing::debug!("POST {}", url); diff --git a/rig-core/src/providers/voyageai.rs b/rig-core/src/providers/voyageai.rs index 9dd969016..8f1362a66 100644 --- a/rig-core/src/providers/voyageai.rs +++ b/rig-core/src/providers/voyageai.rs @@ -100,7 +100,7 @@ where impl Client { pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); self.http_client.post(url).bearer_auth(&self.api_key) } } diff --git a/rig-core/src/providers/xai/client.rs b/rig-core/src/providers/xai/client.rs index b45a25112..f772da940 100644 --- a/rig-core/src/providers/xai/client.rs +++ b/rig-core/src/providers/xai/client.rs @@ -109,7 +109,7 @@ where impl Client { pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); tracing::debug!("POST {}", url); @@ -120,7 +120,7 @@ impl Client { } pub(crate) fn reqwest_get(&self, path: &str) -> reqwest::RequestBuilder { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); tracing::debug!("GET {}", url); diff --git a/rig-eternalai/src/providers/eternalai.rs b/rig-eternalai/src/providers/eternalai.rs index 
3a95082a0..2d3cb0820 100644 --- a/rig-eternalai/src/providers/eternalai.rs +++ b/rig-eternalai/src/providers/eternalai.rs @@ -119,7 +119,7 @@ impl Client { } pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder { - let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + let url = format!("{}/{}", self.base_url, path.trim_start_matches('/')); self.http_client.post(url).bearer_auth(&self.api_key) }