diff --git a/rig-core/src/providers/azure.rs b/rig-core/src/providers/azure.rs index f3b59defc..a83efd8a2 100644 --- a/rig-core/src/providers/azure.rs +++ b/rig-core/src/providers/azure.rs @@ -555,9 +555,6 @@ impl EmbeddingModel { } } -// ================================================================ -// Azure OpenAI Completion API -// ================================================================ /// `o1` completion model pub const O1: &str = "o1"; /// `o1-preview` completion model diff --git a/rig-core/src/providers/azure_ai_foundry/client.rs b/rig-core/src/providers/azure_ai_foundry/client.rs new file mode 100644 index 000000000..f5e12c3f7 --- /dev/null +++ b/rig-core/src/providers/azure_ai_foundry/client.rs @@ -0,0 +1,300 @@ +use bytes::Bytes; +use serde::Deserialize; + +use crate::{ + client::{CompletionClient, EmbeddingsClient, ProviderClient}, + completion::GetTokenUsage, + http_client::{self, HttpClientExt}, + impl_conversion_traits, + providers::azure_ai_foundry::{completion::CompletionModel, embedding::EmbeddingModel}, +}; + +pub const DEFAULT_API_VERSION: &str = "2024-10-21"; + +pub struct ClientBuilder<'a, T = reqwest::Client> { + api_key: &'a str, + api_version: Option<&'a str>, + azure_endpoint: &'a str, + http_client: T, +} + +impl<'a, T> ClientBuilder<'a, T> +where + T: Default, +{ + pub fn new(api_key: &'a str, endpoint: &'a str) -> Self { + Self { + api_key, + api_version: None, + azure_endpoint: endpoint, + http_client: Default::default(), + } + } +} + +impl<'a, T> ClientBuilder<'a, T> { + pub fn new_with_client(api_key: &'a str, azure_endpoint: &'a str, http_client: T) -> Self { + Self { + api_key, + api_version: None, + azure_endpoint, + http_client, + } + } + + /// API version to use (e.g., "2024-10-21" for GA, "2024-10-01-preview" for preview) + pub fn api_version(mut self, api_version: &'a str) -> Self { + self.api_version = Some(api_version); + self + } + + /// Azure OpenAI endpoint URL, for example: 
https://{your-resource-name}.services.ai.azure.com + /// SAFETY: Don't add a forward slash on the end of the URL + pub fn azure_endpoint(mut self, azure_endpoint: &'a str) -> Self { + self.azure_endpoint = azure_endpoint; + self + } + + pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> { + ClientBuilder { + api_key: self.api_key, + api_version: self.api_version, + azure_endpoint: self.azure_endpoint, + http_client, + } + } + + pub fn build(self) -> Client<T> { + let api_version = self.api_version.unwrap_or(DEFAULT_API_VERSION); + + Client { + api_version: api_version.to_string(), + azure_endpoint: self.azure_endpoint.to_string(), + api_key: self.api_key.to_string(), + http_client: self.http_client, + } + } +} + +#[derive(Clone)] +pub struct Client<T = reqwest::Client> { + api_version: String, + azure_endpoint: String, + api_key: String, + pub http_client: T, +} + +impl<T> std::fmt::Debug for Client<T> +where + T: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Client") + .field("azure_endpoint", &self.azure_endpoint) + .field("http_client", &self.http_client) + .field("api_key", &"<REDACTED>") + .field("api_version", &self.api_version) + .finish() + } +} + +impl Client<reqwest::Client> { + /// Create a new Azure AI Foundry client builder. + /// + /// # Example + /// ``` + /// use rig::providers::azure_ai_foundry::client::Client; + /// + /// // Initialize the Azure AI Foundry client + /// let azure = Client::builder("your-azure-api-key", "https://{your-resource-name}.services.ai.azure.com") + /// .build(); + /// ``` + pub fn builder<'a>(api_key: &'a str, endpoint: &'a str) -> ClientBuilder<'a, reqwest::Client> { + ClientBuilder::new(api_key, endpoint) + } + + /// Creates a new Azure OpenAI client. For more control, use the `builder` method.
+ pub fn new(api_key: &str, endpoint: &str) -> Self { + Self::builder(api_key, endpoint).build() + } + + pub fn from_env() -> Self { + <Self as ProviderClient>::from_env() + } +} + +impl<T> Client<T> +where + T: HttpClientExt, +{ + pub fn post(&self, url: String) -> http_client::Builder { + http_client::Request::post(url).header("api-key", &self.api_key) + } + + pub fn post_chat_completion(&self) -> http_client::Builder { + let url = format!( + "{}/models/completions?api-version={}", + self.azure_endpoint, self.api_version + ); + + self.post(url) + } + + pub fn post_embedding(&self) -> http_client::Builder { + let url = format!( + "{}/models/embeddings?api-version={}", + self.azure_endpoint, self.api_version + ); + + self.post(url) + } + + pub async fn send<U, R>( + &self, + req: http_client::Request<U>, + ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>> + where + U: Into<Bytes> + Send, + R: From<Bytes> + Send + 'static, + { + self.http_client.send(req).await + } +} + +impl<T> ProviderClient for Client<T> +where + T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static, +{ + /// Create a new Azure OpenAI client from the `AZURE_API_KEY` or `AZURE_TOKEN`, `AZURE_API_VERSION`, and `AZURE_ENDPOINT` environment variables.
+ fn from_env() -> Self { + let Ok(api_key) = std::env::var("AZURE_API_KEY") else { + panic!("Neither AZURE_API_KEY nor AZURE_TOKEN is set"); + }; + + let api_version = std::env::var("AZURE_API_VERSION").expect("AZURE_API_VERSION not set"); + let azure_endpoint = std::env::var("AZURE_ENDPOINT").expect("AZURE_ENDPOINT not set"); + + ClientBuilder::<T>::new(&api_key, &azure_endpoint) + .api_version(&api_version) + .build() + } + + fn from_val(input: crate::client::ProviderValue) -> Self { + let crate::client::ProviderValue::ApiKeyWithVersionAndHeader(api_key, version, header) = + input + else { + panic!("Incorrect provider value type") + }; + ClientBuilder::<T>::new(&api_key, &header) + .api_version(&version) + .build() + } +} + +impl<T> CompletionClient for Client<T> +where + T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static, +{ + type CompletionModel = super::completion::CompletionModel<T>; + + /// Create a completion model with the given name. + /// + /// # Example + /// ``` + /// use rig::providers::azure_ai_foundry::client::Client; + /// + /// // Initialize the Azure AI Foundry client + /// let azure = Client::new("YOUR_API_KEY", "YOUR_ENDPOINT"); + /// + /// let gpt4 = azure.completion_model("gpt-4o"); + /// ``` + fn completion_model(&self, model: &str) -> Self::CompletionModel { + CompletionModel::new(self.clone(), model) + } +} + +impl<T> EmbeddingsClient for Client<T> +where + T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static, +{ + type EmbeddingModel = super::embedding::EmbeddingModel<T>; + + /// Create an embedding model with the given name. + /// Note: default embedding dimension of 0 will be used if model is not known.
/// If this is the case, it's better to use function `embedding_model_with_ndims` + /// + /// # Example + /// ``` + /// use rig::providers::azure_ai_foundry::{client::Client, embedding}; + /// + /// // Initialize the Azure AI Foundry client + /// let azure = Client::new("YOUR_API_KEY", "YOUR_ENDPOINT"); + /// + /// let embedding_model = azure.embedding_model(embedding::TEXT_EMBEDDING_3_LARGE); + /// ``` + fn embedding_model(&self, model: &str) -> Self::EmbeddingModel { + let ndims = 0; + EmbeddingModel::new(self.clone(), model, ndims) + } + + /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model. + /// + /// # Example + /// ``` + /// use rig::providers::azure_ai_foundry::client::Client; + /// + /// // Initialize the Azure AI Foundry client + /// let azure = Client::new("YOUR_API_KEY", "YOUR_ENDPOINT"); + /// + /// let embedding_model = azure.embedding_model_with_ndims("model-unknown-to-rig", 3072); + /// ``` + fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel { + EmbeddingModel::new(self.clone(), model, ndims) + } +} + +impl_conversion_traits!( + AsTranscription, + AsImageGeneration, + AsAudioGeneration for Client +); + +#[derive(Debug, Deserialize)] +pub struct ApiErrorResponse { + pub message: String, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +pub enum ApiResponse<T> { + Ok(T), + Err(ApiErrorResponse), +} + +#[derive(Clone, Debug, Deserialize)] +pub struct Usage { + pub prompt_tokens: usize, + pub total_tokens: usize, +} + +impl GetTokenUsage for Usage { + fn token_usage(&self) -> Option<crate::completion::Usage> { + let mut usage = crate::completion::Usage::new(); + + usage.input_tokens = self.prompt_tokens as u64; + usage.total_tokens = self.total_tokens as u64; + usage.output_tokens = usage.total_tokens - usage.input_tokens; + + Some(usage) + } +} + +impl std::fmt::Display for Usage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Prompt tokens: {}
Total tokens: {}", + self.prompt_tokens, self.total_tokens + ) + } +} diff --git a/rig-core/src/providers/azure_ai_foundry/completion.rs b/rig-core/src/providers/azure_ai_foundry/completion.rs new file mode 100644 index 000000000..db18b3d54 --- /dev/null +++ b/rig-core/src/providers/azure_ai_foundry/completion.rs @@ -0,0 +1,203 @@ +use crate::{ + json_utils::merge, + providers::{azure_ai_foundry::client::ApiResponse, openai::send_compatible_streaming_request}, + streaming::StreamingCompletionResponse, + telemetry::SpanCombinator, +}; +use bytes::Bytes; +use serde_json::json; +use tracing::Instrument; +use tracing::info_span; + +use crate::{ + completion::{self, CompletionError, CompletionRequest}, + http_client::{self, HttpClientExt}, + json_utils, + providers::{azure_ai_foundry::client::Client, openai}, +}; + +#[derive(Clone)] +pub struct CompletionModel<T = reqwest::Client> { + client: Client<T>, + /// Name of the model (e.g.: gpt-4o-mini) + pub model: String, +} + +impl<T> CompletionModel<T> { + pub fn new(client: Client<T>, model: &str) -> Self { + Self { + client, + model: model.to_string(), + } + } + + fn create_completion_request( + &self, + completion_request: CompletionRequest, + ) -> Result<serde_json::Value, CompletionError> { + let mut full_history: Vec<openai::Message> = match &completion_request.preamble { + Some(preamble) => vec![openai::Message::system(preamble)], + None => vec![], + }; + if let Some(docs) = completion_request.normalized_documents() { + let docs: Vec<openai::Message> = docs.try_into()?; + full_history.extend(docs); + } + let chat_history: Vec<openai::Message> = completion_request + .chat_history + .into_iter() + .map(|message| message.try_into()) + .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
+ .into_iter() + .flatten() + .collect(); + + full_history.extend(chat_history); + + let request = if completion_request.tools.is_empty() { + json!({ + "model": self.model, + "messages": full_history, + "temperature": completion_request.temperature, + }) + } else { + json!({ + "model": self.model, + "messages": full_history, + "temperature": completion_request.temperature, + "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(), + "tool_choice": "auto", + }) + }; + + let request = if let Some(params) = completion_request.additional_params { + json_utils::merge(request, params) + } else { + request + }; + + Ok(request) + } +} + +impl<T> completion::CompletionModel for CompletionModel<T> +where + T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static, +{ + type Response = openai::completion::CompletionResponse; + type StreamingResponse = openai::StreamingCompletionResponse; + + #[cfg_attr(feature = "worker", worker::send)] + async fn completion( + &self, + completion_request: CompletionRequest, + ) -> Result<completion::CompletionResponse<Self::Response>, CompletionError> { + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat", + gen_ai.operation.name = "chat", + gen_ai.provider.name = "azure.openai", + gen_ai.request.model = self.model, + gen_ai.system_instructions = &completion_request.preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, + gen_ai.input.messages = tracing::field::Empty, + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + let request = self.create_completion_request(completion_request)?; + span.record_model_input( + &request + .get("messages") + .expect("Converting JSON should not fail"), + ); + let body = serde_json::to_vec(&request)?; + + let req = self + .client + .post_chat_completion() +
.header("Content-Type", "application/json") + .body(body) + .map_err(http_client::Error::from)?; + + async move { + let response = self.client.http_client.send::<_, Bytes>(req).await.unwrap(); + + let status = response.status(); + let response_body = response.into_body().into_future().await?.to_vec(); + + if status.is_success() { + match serde_json::from_slice::<ApiResponse<Self::Response>>(&response_body)? { + ApiResponse::Ok(response) => { + let span = tracing::Span::current(); + span.record_model_output(&response.choices); + span.record_response_metadata(&response); + span.record_token_usage(&response.usage); + tracing::debug!(target: "rig", "Azure completion output: {}", serde_json::to_string_pretty(&response)?); + response.try_into() + } + ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), + } + } else { + Err(CompletionError::ProviderError( + String::from_utf8_lossy(&response_body).to_string() + )) + } + } + .instrument(span) + .await + } + + #[cfg_attr(feature = "worker", worker::send)] + async fn stream( + &self, + request: CompletionRequest, + ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> { + let preamble = request.preamble.clone(); + let mut request = self.create_completion_request(request)?; + + request = merge( + request, + json!({"stream": true, "stream_options": {"include_usage": true}}), + ); + + let body = serde_json::to_vec(&request)?; + + let req = self + .client + .post_chat_completion() + .header("Content-Type", "application/json") + .body(body) + .map_err(http_client::Error::from)?; + + let span = if tracing::Span::current().is_disabled() { + info_span!( + target: "rig::completions", + "chat_streaming", + gen_ai.operation.name = "chat_streaming", + gen_ai.provider.name = "azure.openai", + gen_ai.request.model = self.model, + gen_ai.system_instructions = &preamble, + gen_ai.response.id = tracing::field::Empty, + gen_ai.response.model = tracing::field::Empty, + gen_ai.usage.output_tokens = tracing::field::Empty, + gen_ai.usage.input_tokens = tracing::field::Empty, +
gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.output.messages = tracing::field::Empty, + ) + } else { + tracing::Span::current() + }; + + tracing_futures::Instrument::instrument( + send_compatible_streaming_request(self.client.http_client.clone(), req), + span, + ) + .await + } +} diff --git a/rig-core/src/providers/azure_ai_foundry/embedding.rs b/rig-core/src/providers/azure_ai_foundry/embedding.rs new file mode 100644 index 000000000..f162a125d --- /dev/null +++ b/rig-core/src/providers/azure_ai_foundry/embedding.rs @@ -0,0 +1,131 @@ +use serde::Deserialize; +use serde_json::json; + +use crate::{ + embeddings::{self, EmbeddingError}, + http_client::{self, HttpClientExt}, + providers::azure_ai_foundry::client::{ApiErrorResponse, ApiResponse, Client, Usage}, +}; + +// ================================================================ +// Azure OpenAI Embedding API +// ================================================================ +/// `text-embedding-3-large` embedding model +pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large"; +/// `text-embedding-3-small` embedding model +pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small"; +/// `text-embedding-ada-002` embedding model +pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002"; + +#[derive(Debug, Deserialize)] +pub struct EmbeddingResponse { + pub object: String, + pub data: Vec<EmbeddingData>, + pub model: String, + pub usage: Usage, +} + +impl From<ApiErrorResponse> for EmbeddingError { + fn from(err: ApiErrorResponse) -> Self { + EmbeddingError::ProviderError(err.message) + } +} + +impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> { + fn from(value: ApiResponse<EmbeddingResponse>) -> Self { + match value { + ApiResponse::Ok(response) => Ok(response), + ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), + } + } +} + +#[derive(Debug, Deserialize)] +pub struct EmbeddingData { + pub object: String, + pub embedding: Vec<f64>, + pub index: usize, +} + +#[derive(Clone)] +pub struct
EmbeddingModel<T = reqwest::Client> { + client: Client<T>, + pub model: String, + ndims: usize, +} + +impl<T> embeddings::EmbeddingModel for EmbeddingModel<T> +where + T: HttpClientExt + Default + Clone, +{ + const MAX_DOCUMENTS: usize = 1024; + + fn ndims(&self) -> usize { + self.ndims + } + + #[cfg_attr(feature = "worker", worker::send)] + async fn embed_texts( + &self, + documents: impl IntoIterator<Item = String>, + ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> { + let documents = documents.into_iter().collect::<Vec<_>>(); + + let body = serde_json::to_vec(&json!({ + "input": documents, + }))?; + + let req = self + .client + .post_embedding() + .header("Content-Type", "application/json") + .body(body) + .map_err(|e| EmbeddingError::HttpError(e.into()))?; + + let response = self.client.send(req).await?; + + if response.status().is_success() { + let body: Vec<u8> = response.into_body().await?; + let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?; + + match body { + ApiResponse::Ok(response) => { + tracing::info!(target: "rig", + "Azure embedding token usage: {}", + response.usage + ); + + if response.data.len() != documents.len() { + return Err(EmbeddingError::ResponseError( + "Response data length does not match input length".into(), + )); + } + + Ok(response + .data + .into_iter() + .zip(documents.into_iter()) + .map(|(embedding, document)| embeddings::Embedding { + document, + vec: embedding.embedding, + }) + .collect()) + } + ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)), + } + } else { + let text = http_client::text(response).await?; + Err(EmbeddingError::ProviderError(text)) + } + } +} + +impl<T> EmbeddingModel<T> { + pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self { + Self { + client, + model: model.to_string(), + ndims, + } + } +} diff --git a/rig-core/src/providers/azure_ai_foundry/mod.rs b/rig-core/src/providers/azure_ai_foundry/mod.rs new file mode 100644 index 000000000..0a57ad5ff --- /dev/null +++ b/rig-core/src/providers/azure_ai_foundry/mod.rs @@ -0,0 +1,5 @@ +//!
Rig bindings for the Azure AI Foundry API. + +pub mod client; +pub mod completion; +pub mod embedding; diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs index 53ca700f5..dd7599ade 100644 --- a/rig-core/src/providers/mod.rs +++ b/rig-core/src/providers/mod.rs @@ -47,6 +47,7 @@ //! be used with the Cohere provider client. pub mod anthropic; pub mod azure; +pub mod azure_ai_foundry; pub mod cohere; pub mod deepseek; pub mod galadriel;