diff --git a/rig-core/src/providers/anthropic/completion.rs b/rig-core/src/providers/anthropic/completion.rs
index 3b1aa7fe2..1bd0a2fb8 100644
--- a/rig-core/src/providers/anthropic/completion.rs
+++ b/rig-core/src/providers/anthropic/completion.rs
@@ -4,7 +4,6 @@ use crate::{
     OneOrMany,
     completion::{self, CompletionError, GetTokenUsage},
     http_client::HttpClientExt,
-    json_utils,
     message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
     one_or_many::string_or_one_or_many,
     telemetry::{ProviderResponseExt, SpanCombinator},
@@ -17,7 +16,6 @@ use crate::completion::CompletionRequest;
 use crate::providers::anthropic::streaming::StreamingCompletionResponse;
 use bytes::Bytes;
 use serde::{Deserialize, Serialize};
-use serde_json::json;
 use tracing::{Instrument, info_span};
 
 // ================================================================
@@ -721,6 +719,68 @@ impl TryFrom<message::ToolChoice> for ToolChoice {
         Ok(res)
     }
 }
+
+#[derive(Debug, Deserialize, Serialize)]
+struct AnthropicCompletionRequest {
+    model: String,
+    messages: Vec<Message>,
+    max_tokens: u64,
+    system: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tool_choice: Option<ToolChoice>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    tools: Vec<ToolDefinition>,
+    #[serde(flatten, skip_serializing_if = "Option::is_none")]
+    additional_params: Option<serde_json::Value>,
+}
+
+impl TryFrom<(&str, CompletionRequest)> for AnthropicCompletionRequest {
+    type Error = CompletionError;
+
+    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
+        // Check if max_tokens is set, required for Anthropic
+        let Some(max_tokens) = req.max_tokens else {
+            return Err(CompletionError::RequestError(
+                "`max_tokens` must be set for Anthropic".into(),
+            ));
+        };
+
+        let mut full_history = vec![];
+        if let Some(docs) = req.normalized_documents() {
+            full_history.push(docs);
+        }
+        full_history.extend(req.chat_history);
+
+        let messages = full_history
+            .into_iter()
+            .map(Message::try_from)
+            .collect::<Result<Vec<Message>, _>>()?;
+
+        let tools = req
+            .tools
+            .into_iter()
+            .map(|tool| ToolDefinition {
+                name: tool.name,
+                description: Some(tool.description),
+                input_schema: tool.parameters,
+            })
+            .collect::<Vec<_>>();
+
+        Ok(Self {
+            model: model.to_string(),
+            messages,
+            max_tokens,
+            system: req.preamble.unwrap_or_default(),
+            temperature: req.temperature,
+            tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
+            tools,
+            additional_params: req.additional_params,
+        })
+    }
+}
+
 impl<T> completion::CompletionModel for CompletionModel<T>
 where
     T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
@@ -736,7 +796,7 @@ where
     #[cfg_attr(feature = "worker", worker::send)]
     async fn completion(
         &self,
-        completion_request: completion::CompletionRequest,
+        mut completion_request: completion::CompletionRequest,
     ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
         let span = if tracing::Span::current().is_disabled() {
             info_span!(
                 target: "rig::completions",
@@ -744,7 +804,7 @@ where
                 "chat",
                 gen_ai.operation.name = "chat",
                 gen_ai.provider.name = "anthropic",
-                gen_ai.request.model = self.model,
+                gen_ai.request.model = &self.model,
                 gen_ai.system_instructions = &completion_request.preamble,
                 gen_ai.response.id = tracing::field::Empty,
                 gen_ai.response.model = tracing::field::Empty,
@@ -756,75 +816,21 @@ where
         } else {
             tracing::Span::current()
         };
-        // Note: Ideally we'd introduce provider-specific Request models to handle the
-        // specific requirements of each provider. For now, we just manually check while
-        // building the request as a raw JSON document.
 
         // Check if max_tokens is set, required for Anthropic
-        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
-            tokens
-        } else if let Some(tokens) = self.default_max_tokens {
-            tokens
-        } else {
-            return Err(CompletionError::RequestError(
-                "`max_tokens` must be set for Anthropic".into(),
-            ));
-        };
-
-        let mut full_history = vec![];
-        if let Some(docs) = completion_request.normalized_documents() {
-            full_history.push(docs);
-        }
-        full_history.extend(completion_request.chat_history);
-        span.record_model_input(&full_history);
-
-        let full_history = full_history
-            .into_iter()
-            .map(Message::try_from)
-            .collect::<Result<Vec<Message>, _>>()?;
-
-        let mut request = json!({
-            "model": self.model,
-            "messages": full_history,
-            "max_tokens": max_tokens,
-            "system": completion_request.preamble.unwrap_or("".to_string()),
-        });
-
-        if let Some(temperature) = completion_request.temperature {
-            json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
-        }
-
-        let tool_choice = if let Some(tool_choice) = completion_request.tool_choice {
-            Some(ToolChoice::try_from(tool_choice)?)
-        } else {
-            None
-        };
-
-        if !completion_request.tools.is_empty() {
-            let mut tools_json = json!({
-                "tools": completion_request
-                    .tools
-                    .into_iter()
-                    .map(|tool| ToolDefinition {
-                        name: tool.name,
-                        description: Some(tool.description),
-                        input_schema: tool.parameters,
-                    })
-                    .collect::<Vec<_>>(),
-            });
-
-            // Only include tool_choice if it's explicitly set (not None)
-            // When omitted, Anthropic defaults to "auto"
-            if let Some(tc) = tool_choice {
-                tools_json["tool_choice"] = serde_json::to_value(tc)?;
+        if completion_request.max_tokens.is_none() {
+            if let Some(tokens) = self.default_max_tokens {
+                completion_request.max_tokens = Some(tokens);
+            } else {
+                return Err(CompletionError::RequestError(
+                    "`max_tokens` must be set for Anthropic".into(),
+                ));
             }
-
-            json_utils::merge_inplace(&mut request, tools_json);
         }
 
-        if let Some(ref params) = completion_request.additional_params {
-            json_utils::merge_inplace(&mut request, params.clone())
-        }
+        let request =
+            AnthropicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
+        span.record_model_input(&request.messages);
 
         async move {
             let request: Vec<u8> = serde_json::to_vec(&request)?;
@@ -909,6 +915,7 @@ enum ApiResponse {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use serde_json::json;
     use serde_path_to_error::deserialize;
 
     #[test]
diff --git a/rig-core/src/providers/azure.rs b/rig-core/src/providers/azure.rs
index 352e9a553..8a44b14b9 100644
--- a/rig-core/src/providers/azure.rs
+++ b/rig-core/src/providers/azure.rs
@@ -19,7 +19,6 @@ use crate::client::{
 };
 use crate::completion::GetTokenUsage;
 use crate::http_client::{self, HttpClientExt, bearer_auth_header};
-use crate::json_utils::merge;
 use crate::streaming::StreamingCompletionResponse;
 use crate::transcription::TranscriptionError;
 use crate::{
@@ -32,7 +31,7 @@ use crate::{
 };
 use bytes::Bytes;
 use reqwest::multipart::Part;
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 use serde_json::json;
 
 // ================================================================
 // Main Azure OpenAI Client
@@ -544,42 +543,44 @@ pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
 /// `gpt-3.5-turbo-16k` completion model
 pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";
 
-#[derive(Clone)]
-pub struct CompletionModel<T> {
-    client: Client<T>,
-    /// Name of the model (e.g.: gpt-4o-mini)
-    pub model: String,
+#[derive(Debug, Serialize, Deserialize)]
+pub(super) struct AzureOpenAICompletionRequest {
+    model: String,
+    pub messages: Vec<openai::Message>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    tools: Vec<openai::ToolDefinition>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
+    #[serde(flatten, skip_serializing_if = "Option::is_none")]
+    pub additional_params: Option<serde_json::Value>,
 }
 
-impl<T> CompletionModel<T> {
-    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
-        Self {
-            client,
-            model: model.into(),
-        }
-    }
+impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest {
+    type Error = CompletionError;
 
-    pub fn with_model(client: Client<T>, model: &str) -> Self {
-        Self {
-            client,
-            model: model.into(),
+    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
+        // FIXME: Must fix!
+        if req.tool_choice.is_some() {
+            tracing::warn!(
+                "Tool choice is currently not supported in Azure OpenAI. This should be fixed by Rig 0.25."
+            );
         }
-    }
 
-    fn create_completion_request(
-        &self,
-        completion_request: CompletionRequest,
-    ) -> Result<Value, CompletionError> {
-        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
+        let mut full_history: Vec<openai::Message> = match &req.preamble {
             Some(preamble) => vec![openai::Message::system(preamble)],
             None => vec![],
         };
-        if let Some(docs) = completion_request.normalized_documents() {
+
+        if let Some(docs) = req.normalized_documents() {
             let docs: Vec<openai::Message> = docs.try_into()?;
             full_history.extend(docs);
         }
-        let chat_history: Vec<openai::Message> = completion_request
+
+        let chat_history: Vec<openai::Message> = req
             .chat_history
+            .clone()
             .into_iter()
             .map(|message| message.try_into())
             .collect::<Result<Vec<Vec<openai::Message>>, _>>()?
@@ -589,29 +590,41 @@ impl<T> CompletionModel<T> {
 
         full_history.extend(chat_history);
 
-        let request = if completion_request.tools.is_empty() {
-            json!({
-                "model": self.model,
-                "messages": full_history,
-                "temperature": completion_request.temperature,
-            })
-        } else {
-            json!({
-                "model": self.model,
-                "messages": full_history,
-                "temperature": completion_request.temperature,
-                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
-                "tool_choice": "auto",
-            })
-        };
+        let tool_choice = req
+            .tool_choice
+            .clone()
+            .map(crate::providers::openrouter::ToolChoice::try_from)
+            .transpose()?;
+
+        Ok(Self {
+            model: model.to_string(),
+            messages: full_history,
+            temperature: req.temperature,
+            tools: req
+                .tools
+                .clone()
+                .into_iter()
+                .map(openai::ToolDefinition::from)
+                .collect::<Vec<_>>(),
+            tool_choice,
+            additional_params: req.additional_params,
+        })
+    }
+}
 
-        let request = if let Some(params) = completion_request.additional_params {
-            json_utils::merge(request, params)
-        } else {
-            request
-        };
+#[derive(Clone)]
+pub struct CompletionModel<T> {
+    client: Client<T>,
+    /// Name of the model (e.g.: gpt-4o-mini)
+    pub model: String,
+}
 
-        Ok(request)
+impl<T> CompletionModel<T> {
+    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
+        Self {
+            client,
+            model: model.into(),
+        }
     }
 }
 
@@ -650,12 +663,11 @@ where
         } else {
             tracing::Span::current()
         };
-        let request = self.create_completion_request(completion_request)?;
-        span.record_model_input(
-            &request
-                .get("messages")
-                .expect("Converting JSON should not fail"),
-        );
+
+        let request =
+            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;
+
+        span.record_model_input(&request.messages);
 
         let body = serde_json::to_vec(&request)?;
         let req = self
@@ -707,16 +719,19 @@ where
     #[cfg_attr(feature = "worker", worker::send)]
     async fn stream(
         &self,
-        request: CompletionRequest,
+        completion_request: CompletionRequest,
     ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
-        let preamble = request.preamble.clone();
-        let mut request = self.create_completion_request(request)?;
+        let preamble = completion_request.preamble.clone();
+        let mut request =
+            AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;
 
-        request = merge(
-            request,
-            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        let params = json_utils::merge(
+            request.additional_params.unwrap_or(serde_json::json!({})),
+            serde_json::json!({"stream": true, "stream_options": {"include_usage": true}}),
         );
 
+        request.additional_params = Some(params);
+
         let body = serde_json::to_vec(&request)?;
 
         let req = self
@@ -738,7 +753,7 @@ where
                 gen_ai.response.model = tracing::field::Empty,
                 gen_ai.usage.output_tokens = tracing::field::Empty,
                 gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                 gen_ai.output.messages = tracing::field::Empty,
             )
         } else {
diff --git a/rig-core/src/providers/cohere/completion.rs b/rig-core/src/providers/cohere/completion.rs
index 70ce123bd..5f5245f9e 100644
--- a/rig-core/src/providers/cohere/completion.rs
+++ b/rig-core/src/providers/cohere/completion.rs
@@ -12,7 +12,6 @@ use super::client::Client;
 use crate::completion::CompletionRequest;
 use crate::providers::cohere::streaming::StreamingCompletionResponse;
 use serde::{Deserialize, Serialize};
-use serde_json::{Value, json};
 use tracing::{Instrument, info_span};
 
 #[derive(Debug, Deserialize, Serialize)]
@@ -520,43 +519,35 @@ pub struct CompletionModel<T> {
     pub model: String,
 }
 
-impl<T> CompletionModel<T>
-where
-    T: HttpClientExt,
-{
-    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
-        Self {
-            client,
-            model: model.into(),
-        }
-    }
+#[derive(Debug, Serialize, Deserialize)]
+pub(super) struct CohereCompletionRequest {
+    model: String,
+    pub messages: Vec<Message>,
+    documents: Vec<completion::Document>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    tools: Vec<Tool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tool_choice: Option<ToolChoice>,
+    #[serde(flatten, skip_serializing_if = "Option::is_none")]
+    pub additional_params: Option<serde_json::Value>,
+}
 
-    pub fn with_model(client: Client<T>, model: &str) -> Self {
-        Self {
-            client,
-            model: model.to_string(),
-        }
-    }
+impl TryFrom<(&str, CompletionRequest)> for CohereCompletionRequest {
+    type Error = CompletionError;
 
-    pub(crate) fn create_completion_request(
-        &self,
-        completion_request: CompletionRequest,
-    ) -> Result<Value, CompletionError> {
-        // Build up the order of messages (context, chat_history)
+    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
         let mut partial_history = vec![];
-        if let Some(docs) = completion_request.normalized_documents() {
+        if let Some(docs) = req.normalized_documents() {
             partial_history.push(docs);
         }
-        partial_history.extend(completion_request.chat_history);
+        partial_history.extend(req.chat_history);
 
-        // Initialize full history with preamble (or empty if non-existent)
-        let mut full_history: Vec<Message> = completion_request
-            .preamble
-            .map_or_else(Vec::new, |preamble| {
-                vec![Message::System { content: preamble }]
-            });
+        let mut full_history: Vec<Message> = req.preamble.map_or_else(Vec::new, |preamble| {
+            vec![Message::System { content: preamble }]
+        });
 
-        // Convert and extend the rest of the history
         full_history.extend(
             partial_history
                 .into_iter()
@@ -567,21 +558,38 @@ where
                 .collect::<Vec<_>>(),
         );
 
-        let request = json!({
-            "model": self.model,
-            "messages": full_history,
-            "documents": completion_request.documents,
-            "temperature": completion_request.temperature,
-            "tools": completion_request.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
-            "tool_choice": if let Some(tool_choice) = completion_request.tool_choice && !matches!(tool_choice, ToolChoice::Auto) { tool_choice } else {
-                return Err(CompletionError::RequestError("\"auto\" is not an allowed tool_choice value in the Cohere API".into()))
-            },
-        });
-
-        if let Some(ref params) = completion_request.additional_params {
-            Ok(json_utils::merge(request.clone(), params.clone()))
+        let tool_choice = if let Some(tool_choice) = req.tool_choice {
+            if !matches!(tool_choice, ToolChoice::Auto) {
+                Some(tool_choice)
+            } else {
+                return Err(CompletionError::RequestError(
+                    "\"auto\" is not an allowed tool_choice value in the Cohere API".into(),
+                ));
+            }
         } else {
-            Ok(request)
+            None
+        };
+
+        Ok(Self {
+            model: model.to_string(),
+            messages: full_history,
+            documents: req.documents,
+            temperature: req.temperature,
+            tools: req.tools.into_iter().map(Tool::from).collect::<Vec<_>>(),
+            tool_choice,
+            additional_params: req.additional_params,
+        })
+    }
+}
+
+impl<T> CompletionModel<T>
+where
+    T: HttpClientExt,
+{
+    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
+        Self {
+            client,
+            model: model.into(),
         }
     }
 }
@@ -603,7 +611,7 @@ where
         &self,
         completion_request: completion::CompletionRequest,
     ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
-        let request = self.create_completion_request(completion_request)?;
+        let request = CohereCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
 
         let llm_span = if tracing::Span::current().is_disabled() {
             info_span!(
@@ -616,7 +624,7 @@ where
                 gen_ai.response.model = self.model,
                 gen_ai.usage.output_tokens = tracing::field::Empty,
                 gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(request.get("messages").expect("Converting request messages to JSON should not fail!")).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                 gen_ai.output.messages = tracing::field::Empty,
             )
         } else {
diff --git a/rig-core/src/providers/cohere/streaming.rs b/rig-core/src/providers/cohere/streaming.rs
index 253e43199..35f104c33 100644
--- a/rig-core/src/providers/cohere/streaming.rs
+++ b/rig-core/src/providers/cohere/streaming.rs
@@ -3,7 +3,7 @@ use crate::http_client::HttpClientExt;
 use crate::http_client::sse::{Event, GenericEventSource};
 use crate::providers::cohere::CompletionModel;
 use crate::providers::cohere::completion::{
-    AssistantContent, Message, ToolCall, ToolCallFunction, ToolType, Usage,
+    AssistantContent, CohereCompletionRequest, Message, ToolCall, ToolCallFunction, ToolType, Usage,
 };
 use crate::streaming::RawStreamingChoice;
 use crate::telemetry::SpanCombinator;
@@ -99,7 +99,7 @@ where
         request: CompletionRequest,
     ) -> Result<crate::streaming::StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
     {
-        let request = self.create_completion_request(request)?;
+        let mut request = CohereCompletionRequest::try_from((self.model.as_ref(), request))?;
 
         let span = if tracing::Span::current().is_disabled() {
             info_span!(
                 target: "rig::completions",
@@ -111,14 +111,19 @@ where
                 gen_ai.response.model = self.model,
                 gen_ai.usage.output_tokens = tracing::field::Empty,
                 gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                 gen_ai.output.messages = tracing::field::Empty,
             )
         } else {
             tracing::Span::current()
         };
 
-        let request = json_utils::merge(request, serde_json::json!({"stream": true}));
+        let params = json_utils::merge(
+            request.additional_params.unwrap_or(serde_json::json!({})),
+            serde_json::json!({"stream": true}),
+        );
+
+        request.additional_params = Some(params);
 
         tracing::trace!(
             target: "rig::streaming",
diff --git a/rig-core/src/providers/deepseek.rs b/rig-core/src/providers/deepseek.rs
index c56a16071..c15030053 100644
--- a/rig-core/src/providers/deepseek.rs
+++ b/rig-core/src/providers/deepseek.rs
@@ -9,6 +9,7 @@
 //! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT);
 //! ```
 
+use crate::json_utils::empty_or_none;
 use async_stream::stream;
 use bytes::Bytes;
 use futures::StreamExt;
@@ -23,7 +24,6 @@ use crate::client::{
 use crate::completion::GetTokenUsage;
 use crate::http_client::sse::{Event, GenericEventSource};
 use crate::http_client::{self, HttpClientExt};
-use crate::json_utils::{empty_or_none, merge};
 use crate::message::{Document, DocumentSourceKind};
 use crate::{
     OneOrMany,
@@ -31,7 +31,6 @@ use crate::{
     json_utils, message,
 };
 use serde::{Deserialize, Serialize};
-use serde_json::json;
 
 use super::openai::StreamingToolCall;
 
@@ -430,74 +429,75 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse>
-#[derive(Clone)]
-pub struct CompletionModel<T> {
-    pub client: Client<T>,
-    pub model: String,
-}
+#[derive(Debug, Serialize, Deserialize)]
+pub(super) struct DeepseekCompletionRequest {
+    model: String,
+    pub messages: Vec<Message>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    tools: Vec<ToolDefinition>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
+    #[serde(flatten, skip_serializing_if = "Option::is_none")]
+    pub additional_params: Option<serde_json::Value>,
+}
 
-impl<T> CompletionModel<T> {
-    fn create_completion_request(
-        &self,
-        completion_request: CompletionRequest,
-    ) -> Result<Value, CompletionError> {
-        // Build up the order of messages (context, chat_history, prompt)
-        let mut partial_history = vec![];
+impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest {
+    type Error = CompletionError;
 
-        if let Some(docs) = completion_request.normalized_documents() {
-            partial_history.push(docs);
-        }
+    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
+        let mut full_history: Vec<Message> = match &req.preamble {
+            Some(preamble) => vec![Message::system(preamble)],
+            None => vec![],
+        };
 
-        partial_history.extend(completion_request.chat_history);
+        if let Some(docs) = req.normalized_documents() {
+            let docs: Vec<Message> = docs.try_into()?;
+            full_history.extend(docs);
+        }
 
-        // Initialize full history with preamble (or empty if non-existent)
-        let mut full_history: Vec<Message> = completion_request
-            .preamble
-            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
+        let chat_history: Vec<Message> = req
+            .chat_history
+            .clone()
+            .into_iter()
+            .map(|message| message.try_into())
+            .collect::<Result<Vec<Vec<Message>>, _>>()?
+            .into_iter()
+            .flatten()
+            .collect();
 
-        // Convert and extend the rest of the history
-        full_history.extend(
-            partial_history
-                .into_iter()
-                .map(message::Message::try_into)
-                .collect::<Result<Vec<Vec<Message>>, _>>()?
-                .into_iter()
-                .flatten()
-                .collect::<Vec<Message>>(),
-        );
+        full_history.extend(chat_history);
 
-        let tool_choice = completion_request
+        let tool_choice = req
             .tool_choice
+            .clone()
             .map(crate::providers::openrouter::ToolChoice::try_from)
             .transpose()?;
 
-        let request = if completion_request.tools.is_empty() {
-            json!({
-                "model": self.model,
-                "messages": full_history,
-                "temperature": completion_request.temperature,
-            })
-        } else {
-            json!({
-                "model": self.model,
-                "messages": full_history,
-                "temperature": completion_request.temperature,
-                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
-                "tool_choice": tool_choice,
-            })
-        };
-
-        let request = if let Some(params) = completion_request.additional_params {
-            json_utils::merge(request, params)
-        } else {
-            request
-        };
-
-        Ok(request)
+        Ok(Self {
+            model: model.to_string(),
+            messages: full_history,
+            temperature: req.temperature,
+            tools: req
+                .tools
+                .clone()
+                .into_iter()
+                .map(ToolDefinition::from)
+                .collect::<Vec<_>>(),
+            tool_choice,
+            additional_params: req.additional_params,
+        })
     }
 }
 
+/// The struct implementing the `CompletionModel` trait
+#[derive(Clone)]
+pub struct CompletionModel<T> {
+    pub client: Client<T>,
+    pub model: String,
+}
+
 impl<T> completion::CompletionModel for CompletionModel<T>
 where
     T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
@@ -523,7 +523,8 @@ where
         crate::completion::CompletionError,
     > {
         let preamble = completion_request.preamble.clone();
-        let request = self.create_completion_request(completion_request)?;
+        let request =
+            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
 
         let span = if tracing::Span::current().is_disabled() {
             info_span!(
@@ -537,7 +538,7 @@ where
                 gen_ai.response.model = tracing::field::Empty,
                 gen_ai.usage.output_tokens = tracing::field::Empty,
                 gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                 gen_ai.output.messages = tracing::field::Empty,
             )
         } else {
@@ -600,13 +601,16 @@ where
         CompletionError,
     > {
         let preamble = completion_request.preamble.clone();
-        let mut request = self.create_completion_request(completion_request)?;
+        let mut request =
+            DeepseekCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
 
-        request = merge(
-            request,
-            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        let params = json_utils::merge(
+            request.additional_params.unwrap_or(serde_json::json!({})),
+            serde_json::json!({"stream": true, "stream_options": {"include_usage": true}}),
         );
 
+        request.additional_params = Some(params);
+
         let body = serde_json::to_vec(&request)?;
 
         let req = self
@@ -628,7 +632,7 @@ where
                 gen_ai.response.model = tracing::field::Empty,
                 gen_ai.usage.output_tokens = tracing::field::Empty,
                 gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                 gen_ai.output.messages = tracing::field::Empty,
             )
         } else {
diff --git a/rig-core/src/providers/galadriel.rs b/rig-core/src/providers/galadriel.rs
index f718da166..41146e30c 100644
--- a/rig-core/src/providers/galadriel.rs
+++ b/rig-core/src/providers/galadriel.rs
@@ -16,7 +16,6 @@ use crate::client::{
     ProviderClient,
 };
 use crate::http_client::{self, HttpClientExt};
-use crate::json_utils::merge;
 use crate::message::MessageError;
 use crate::providers::openai::send_compatible_streaming_request;
 use crate::streaming::StreamingCompletionResponse;
@@ -26,7 +25,6 @@ use crate::{
     json_utils, message,
 };
 use serde::{Deserialize, Serialize};
-use serde_json::{Value, json};
 use tracing::{Instrument, info_span};
 
 // ================================================================
@@ -291,6 +289,16 @@ pub struct Message {
     pub tool_calls: Vec<ToolCall>,
 }
 
+impl Message {
+    fn system(preamble: &str) -> Self {
+        Self {
+            role: "system".to_string(),
+            content: Some(preamble.to_string()),
+            tool_calls: Vec::new(),
+        }
+    }
+}
+
 impl TryFrom<Message> for message::Message {
     type Error = message::MessageError;
 
@@ -409,49 +417,34 @@ pub struct Function {
     pub arguments: String,
 }
 
-#[derive(Clone)]
-pub struct CompletionModel<T> {
-    client: Client<T>,
-    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
-    pub model: String,
+#[derive(Debug, Serialize, Deserialize)]
+pub(super) struct GaladrielCompletionRequest {
+    model: String,
+    pub messages: Vec<Message>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    tools: Vec<ToolDefinition>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
+    #[serde(flatten, skip_serializing_if = "Option::is_none")]
+    pub additional_params: Option<serde_json::Value>,
 }
 
-impl<T> CompletionModel<T>
-where
-    T: HttpClientExt,
-{
-    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
-        Self {
-            client,
-            model: model.into(),
-        }
-    }
-
-    pub fn with_model(client: Client<T>, model: &str) -> Self {
-        Self {
-            client,
-            model: model.into(),
-        }
-    }
+impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest {
+    type Error = CompletionError;
 
-    pub(crate) fn create_completion_request(
-        &self,
-        completion_request: CompletionRequest,
-    ) -> Result<Value, CompletionError> {
+    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
         // Build up the order of messages (context, chat_history, prompt)
         let mut partial_history = vec![];
-        if let Some(docs) = completion_request.normalized_documents() {
+        if let Some(docs) = req.normalized_documents() {
             partial_history.push(docs);
         }
-        partial_history.extend(completion_request.chat_history);
+        partial_history.extend(req.chat_history);
 
         // Add preamble to chat history (if available)
-        let mut full_history: Vec<Message> = match &completion_request.preamble {
-            Some(preamble) => vec![Message {
-                role: "system".to_string(),
-                content: Some(preamble.to_string()),
-                tool_calls: vec![],
-            }],
+        let mut full_history: Vec<Message> = match &req.preamble {
+            Some(preamble) => vec![Message::system(preamble)],
             None => vec![],
         };
 
@@ -463,35 +456,51 @@ where
                 .collect::<Result<Vec<Message>, _>>()?,
         );
 
-        let tool_choice = completion_request
+        let tool_choice = req
             .tool_choice
             .clone()
             .map(crate::providers::openai::completion::ToolChoice::try_from)
             .transpose()?;
 
-        let request = if completion_request.tools.is_empty() {
-            json!({
-                "model": self.model,
-                "messages": full_history,
-                "temperature": completion_request.temperature,
-            })
-        } else {
-            json!({
-                "model": self.model,
-                "messages": full_history,
-                "temperature": completion_request.temperature,
-                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
-                "tool_choice": tool_choice,
-            })
-        };
+        Ok(Self {
+            model: model.to_string(),
+            messages: full_history,
+            temperature: req.temperature,
+            tools: req
+                .tools
+                .clone()
+                .into_iter()
+                .map(ToolDefinition::from)
+                .collect::<Vec<_>>(),
+            tool_choice,
+            additional_params: req.additional_params,
+        })
+    }
+}
 
-        let request = if let Some(params) = completion_request.additional_params {
-            json_utils::merge(request, params)
-        } else {
-            request
-        };
+#[derive(Clone)]
+pub struct CompletionModel<T> {
+    client: Client<T>,
+    /// Name of the model (e.g.: gpt-3.5-turbo-1106)
+    pub model: String,
+}
 
-        Ok(request)
+impl<T> CompletionModel<T>
+where
+    T: HttpClientExt,
+{
+    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
+        Self {
+            client,
+            model: model.into(),
+        }
+    }
+
+    pub fn with_model(client: Client<T>, model: &str) -> Self {
+        Self {
+            client,
+            model: model.into(),
+        }
     }
 }
 
@@ -514,7 +523,8 @@ where
         completion_request: CompletionRequest,
     ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
         let preamble = completion_request.preamble.clone();
-        let request = self.create_completion_request(completion_request)?;
+        let request =
+            GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
 
         let body = serde_json::to_vec(&request)?;
         let req = self
@@ -536,7 +546,7 @@ where
                 gen_ai.response.model = tracing::field::Empty,
                 gen_ai.usage.output_tokens = tracing::field::Empty,
                 gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                 gen_ai.output.messages = tracing::field::Empty,
             )
         } else {
@@ -583,16 +593,19 @@ where
     #[cfg_attr(feature = "worker", worker::send)]
     async fn stream(
         &self,
-        request: CompletionRequest,
+        completion_request: CompletionRequest,
     ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
-        let preamble = request.preamble.clone();
-        let mut request = self.create_completion_request(request)?;
+        let preamble = completion_request.preamble.clone();
+        let mut request =
+            GaladrielCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
 
-        request = merge(
-            request,
-            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        let params = json_utils::merge(
+            request.additional_params.unwrap_or(serde_json::json!({})),
+            serde_json::json!({"stream": true, "stream_options": {"include_usage": true}}),
         );
 
+        request.additional_params = Some(params);
+
         let body = serde_json::to_vec(&request)?;
 
         let req = self
@@ -614,7 +627,7 @@ where
                 gen_ai.response.model = tracing::field::Empty,
                 gen_ai.usage.output_tokens = tracing::field::Empty,
                 gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                 gen_ai.output.messages = tracing::field::Empty,
             )
         } else {
diff --git a/rig-core/src/providers/groq.rs b/rig-core/src/providers/groq.rs
index 0af54ba2d..7439a995f 100644
--- a/rig-core/src/providers/groq.rs
+++ b/rig-core/src/providers/groq.rs
@@ -22,7 +22,7 @@ use crate::client::{
 use crate::completion::GetTokenUsage;
 use crate::http_client::sse::{Event, GenericEventSource};
 use crate::http_client::{self, HttpClientExt};
-use crate::json_utils::{empty_or_none, merge};
+use crate::json_utils::empty_or_none;
 use crate::providers::openai::{AssistantContent, Function, ToolType};
 use async_stream::stream;
 use futures::StreamExt;
@@ -37,7 +37,6 @@ use crate::{
 };
 use reqwest::multipart::Part;
 use serde::{Deserialize, Serialize};
-use serde_json::{Value, json};
 
 // ================================================================
 // Main Groq Client
@@ -125,6 +124,16 @@ pub struct Message {
     pub reasoning: Option<String>,
 }
 
+impl Message {
+    fn system(preamble: &str) -> Self {
+        Self {
+            role: "system".to_string(),
+            content: Some(preamble.to_string()),
+            reasoning: None,
+        }
+    }
+}
+
 impl TryFrom<Message> for message::Message {
     type Error = message::MessageError;
 
@@ -247,50 +256,43 @@ pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
 /// The `mixtral-8x7b-32768` model. Used for chat completion.
 pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
 
-#[derive(Clone, Debug)]
-pub struct CompletionModel<T> {
-    client: Client<T>,
-    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
-    pub model: String,
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum ReasoningFormat {
+    Parsed,
 }
 
-impl<T> CompletionModel<T> {
-    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
-        Self {
-            client,
-            model: model.into(),
-        }
-    }
+#[derive(Debug, Serialize, Deserialize)]
+pub(super) struct GroqCompletionRequest {
+    model: String,
+    pub messages: Vec<Message>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    tools: Vec<ToolDefinition>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tool_choice: Option<crate::providers::openai::ToolChoice>,
+    #[serde(flatten, skip_serializing_if = "Option::is_none")]
+    pub additional_params: Option<serde_json::Value>,
+    reasoning_format: ReasoningFormat,
+}
 
-    pub fn with_model(client: Client<T>, model: &str) -> Self {
-        Self {
-            client,
-            model: model.into(),
-        }
-    }
+impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest {
+    type Error = CompletionError;
 
-    fn create_completion_request(
-        &self,
-        completion_request: CompletionRequest,
-    ) -> Result<Value, CompletionError> {
+    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
         // Build up the order of messages (context, chat_history, prompt)
         let mut partial_history = vec![];
-        if let Some(docs) = completion_request.normalized_documents() {
+        if let Some(docs) = req.normalized_documents() {
             partial_history.push(docs);
         }
-        partial_history.extend(completion_request.chat_history);
-
-        // Initialize full history with preamble (or empty if non-existent)
-        let mut full_history: Vec<Message> =
-            completion_request
-                .preamble
-                .map_or_else(Vec::new, |preamble| {
-                    vec![Message {
-                        role: "system".to_string(),
-                        content: Some(preamble),
-                        reasoning: None,
-                    }]
-                });
+        partial_history.extend(req.chat_history);
+
+        // Add preamble to chat history (if available)
+        let mut full_history: Vec<Message> = match &req.preamble {
+            Some(preamble) => vec![Message::system(preamble)],
+            None => vec![],
+        };
 
         // Convert and extend the rest of the history
         full_history.extend(
@@ -300,35 +302,42 @@ impl<T> CompletionModel<T> {
                 .collect::<Result<Vec<Message>, _>>()?,
         );
 
-        let tool_choice = completion_request
+        let tool_choice = req
             .tool_choice
+            .clone()
             .map(crate::providers::openai::ToolChoice::try_from)
             .transpose()?;
 
-        let request = if completion_request.tools.is_empty() {
-            json!({
-                "model": self.model,
-                "messages": full_history,
-                "temperature": completion_request.temperature,
-            })
-        } else {
-            json!({
-                "model": self.model,
-                "messages": full_history,
-                "temperature": completion_request.temperature,
-                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
-                "tool_choice": tool_choice,
-                "reasoning_format": "parsed"
-            })
-        };
+        Ok(Self {
+            model: model.to_string(),
+            messages: full_history,
+            temperature: req.temperature,
+            tools: req
+                .tools
+                .clone()
+                .into_iter()
+                .map(ToolDefinition::from)
+                .collect::<Vec<_>>(),
+            tool_choice,
+            additional_params: req.additional_params,
+            reasoning_format: ReasoningFormat::Parsed,
+        })
+    }
+}
 
-        let request = if let Some(params) = completion_request.additional_params {
-            json_utils::merge(request, params)
-        } else {
-            request
-        };
+#[derive(Clone, Debug)]
+pub struct CompletionModel<T> {
+    client: Client<T>,
+    /// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
+    pub model: String,
+}
 
-        Ok(request)
+impl<T> CompletionModel<T> {
+    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
+        Self {
+            client,
+            model: model.into(),
+        }
     }
 }
 
@@ -352,7 +361,7 @@ where
     ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
         let preamble = completion_request.preamble.clone();
 
-        let request = self.create_completion_request(completion_request)?;
+        let request = GroqCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
 
         let span = if tracing::Span::current().is_disabled() {
             info_span!(
                 target: "rig::completions",
@@ -365,7 +374,7 @@ where
                 gen_ai.response.model = tracing::field::Empty,
                 gen_ai.usage.output_tokens = tracing::field::Empty,
                 gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                 gen_ai.output.messages = tracing::field::Empty,
             )
         } else {
@@ -393,7 +402,7 @@ where
             span.record("gen_ai.response.model_name", response.model.clone());
             span.record(
                 "gen_ai.output.messages",
-                serde_json::to_string(&response.choices).unwrap(),
+                serde_json::to_string(&response.choices)?,
             );
             if let Some(ref usage) = response.usage {
                 span.record("gen_ai.usage.input_tokens", usage.prompt_tokens);
@@ -425,13 +434,15 @@ where
         CompletionError,
     > {
         let preamble = request.preamble.clone();
-        let mut request = self.create_completion_request(request)?;
+        let mut request = GroqCompletionRequest::try_from((self.model.as_ref(), request))?;
 
-        request = merge(
-            request,
-            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        let params = json_utils::merge(
+            request.additional_params.unwrap_or(serde_json::json!({})),
+            serde_json::json!({"stream": true, "stream_options": {"include_usage": true}}),
         );
 
+        request.additional_params = Some(params);
+
         let body = serde_json::to_vec(&request)?;
         let req = self
             .client
@@ -452,7 +463,7 @@ where
                 gen_ai.response.model = tracing::field::Empty,
                 gen_ai.usage.output_tokens = tracing::field::Empty,
                 gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                 gen_ai.output.messages = tracing::field::Empty,
             )
         } else {
diff --git a/rig-core/src/providers/huggingface/completion.rs b/rig-core/src/providers/huggingface/completion.rs
index 99946dd5d..6f3c06e7e 100644
--- a/rig-core/src/providers/huggingface/completion.rs
+++ b/rig-core/src/providers/huggingface/completion.rs
@@ -11,7 +11,7 @@ use crate::{
     one_or_many::string_or_one_or_many,
 };
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
-use serde_json::{Value, json};
+use serde_json::Value;
 use std::{convert::Infallible, str::FromStr};
 use tracing::info_span;
 
@@ -608,35 +608,34 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse>
-#[derive(Clone)]
-pub struct CompletionModel<T> {
-    pub(crate) client: Client<T>,
-    /// Name of the model (e.g: google/gemma-2-2b-it)
-    pub model: String,
-}
+#[derive(Debug, Serialize, Deserialize)]
+pub(super) struct HuggingfaceCompletionRequest {
+    model: String,
+    pub messages: Vec<Message>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    tools: Vec<ToolDefinition>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
+    #[serde(flatten, skip_serializing_if = "Option::is_none")]
+    pub additional_params: Option<Value>,
+}
 
-impl<T> CompletionModel<T> {
-    pub fn new(client: Client<T>, model: &str) -> Self {
-        Self {
-            client,
-            model: model.to_string(),
-        }
-    }
+impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest {
+    type Error = CompletionError;
 
-    pub(crate) fn create_request_body(
-        &self,
-        completion_request: &CompletionRequest,
-    ) -> Result<Value, CompletionError> {
-        let mut full_history: Vec<Message> = match &completion_request.preamble {
+    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
+        let mut full_history: Vec<Message> = match &req.preamble {
             Some(preamble) => vec![Message::system(preamble)],
             None => vec![],
         };
 
-        if let Some(docs) = completion_request.normalized_documents() {
+        if let Some(docs) = req.normalized_documents() {
             let docs: Vec<Message> = docs.try_into()?;
             full_history.extend(docs);
         }
 
-        let chat_history: Vec<Message> = completion_request
+        let chat_history: Vec<Message> = req
             .chat_history
             .clone()
             .into_iter()
@@ -648,30 +647,41 @@ impl<T> CompletionModel<T> {
 
         full_history.extend(chat_history);
 
-        let model = self.client.subprovider().model_identifier(&self.model);
-
-        let tool_choice = completion_request
+        let tool_choice = req
             .tool_choice
             .clone()
             .map(crate::providers::openai::completion::ToolChoice::try_from)
             .transpose()?;
 
-        let request = if completion_request.tools.is_empty() {
-            json!({
-                "model": model,
-                "messages": full_history,
-                "temperature": completion_request.temperature,
-            })
-        } else {
-            json!({
-                "model": model,
-                "messages": full_history,
-                "temperature": completion_request.temperature,
-                "tools": completion_request.tools.clone().into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
-                "tool_choice": tool_choice,
-            })
-        };
-        Ok(request)
+        Ok(Self {
+            model: model.to_string(),
+            messages: full_history,
+            temperature: req.temperature,
+            tools: req
+                .tools
+                .clone()
+                .into_iter()
+                .map(ToolDefinition::from)
+                .collect::<Vec<_>>(),
+            tool_choice,
+            additional_params: req.additional_params,
+        })
+    }
+}
+
+#[derive(Clone)]
+pub struct CompletionModel<T> {
+    pub(crate) client: Client<T>,
+    /// Name of the model (e.g: google/gemma-2-2b-it)
+    pub model: String,
+}
+
+impl<T> CompletionModel<T> {
+    pub fn new(client: Client<T>, model: &str) -> Self {
+        Self {
+            client,
+            model: model.to_string(),
+        }
     }
 }
 
@@ -711,19 +721,15 @@ where
         } else {
             tracing::Span::current()
         };
-        let request = self.create_request_body(&completion_request)?;
-        span.record_model_input(&request.get("messages"));
 
-        let path = self.client.subprovider().completion_endpoint(&self.model);
+        let model = self.client.subprovider().model_identifier(&self.model);
+        let request = HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?;
 
-        let request = if let Some(ref params) = completion_request.additional_params {
-            json_utils::merge(request, params.clone())
-        } else {
-            request
-        };
+        span.record_model_input(&request.messages);
 
         let request = serde_json::to_vec(&request)?;
 
+        let path = self.client.subprovider().completion_endpoint(&self.model);
         let request = self
             .client
             .post(&path)?
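---

Reviewer note: every typed request struct introduced above relies on the same serde pattern, so it is worth spelling out once. The sketch below is a minimal, self-contained stand-in (the `ProviderRequest` type is hypothetical, not a rig type); it shows how `skip_serializing_if` keeps unset fields out of the wire body entirely — mirroring the old two-branch `json!` construction — and how `#[serde(flatten)]` splices caller-supplied `additional_params` in as top-level keys, replacing `json_utils::merge`:

```rust
use serde::Serialize;
use serde_json::{Value, json};

// Hypothetical mirror of the shape shared by the new provider request structs.
#[derive(Serialize)]
struct ProviderRequest {
    model: String,
    messages: Vec<String>,
    // Omitted from the JSON body entirely when None, instead of "temperature": null.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Omitted when empty, matching the old `if tools.is_empty()` json! branch.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<String>,
    // `flatten` inlines the object's keys as siblings of `model`/`messages`.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<Value>,
}

fn main() {
    let req = ProviderRequest {
        model: "example-model".into(),
        messages: vec!["hi".into()],
        temperature: None,
        tools: vec![],
        additional_params: Some(json!({"stream": true, "top_p": 0.9})),
    };
    // Prints (key order may vary):
    // {"model":"example-model","messages":["hi"],"stream":true,"top_p":0.9}
    println!("{}", serde_json::to_string(&req).unwrap());
}
```

---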
diff --git a/rig-core/src/providers/huggingface/streaming.rs b/rig-core/src/providers/huggingface/streaming.rs index 9ef01f0c1..2e09b13e0 100644 --- a/rig-core/src/providers/huggingface/streaming.rs +++ b/rig-core/src/providers/huggingface/streaming.rs @@ -1,10 +1,10 @@ use super::completion::CompletionModel; use crate::completion::{CompletionError, CompletionRequest}; use crate::http_client::HttpClientExt; -use crate::json_utils::merge_inplace; +use crate::json_utils::{self}; +use crate::providers::huggingface::completion::HuggingfaceCompletionRequest; use crate::providers::openai::{StreamingCompletionResponse, send_compatible_streaming_request}; use crate::streaming; -use serde_json::json; use tracing::{Instrument, info_span}; impl CompletionModel @@ -16,17 +16,16 @@ where completion_request: CompletionRequest, ) -> Result, CompletionError> { - let mut request = self.create_request_body(&completion_request)?; + let model = self.client.subprovider().model_identifier(&self.model); + let mut request = + HuggingfaceCompletionRequest::try_from((model.as_ref(), completion_request))?; - // Enable streaming - merge_inplace( - &mut request, - json!({"stream": true, "stream_options": {"include_usage": true}}), + let params = json_utils::merge( + request.additional_params.unwrap_or(serde_json::json!({})), + serde_json::json!({"stream": true, "stream_options": {"include_usage": true }}), ); - if let Some(ref params) = completion_request.additional_params { - merge_inplace(&mut request, params.clone()); - } + request.additional_params = Some(params); // HF Inference API uses the model in the path even though its specified in the request body let path = self.client.subprovider().completion_endpoint(&self.model); @@ -51,7 +50,7 @@ where gen_ai.response.model = self.model, gen_ai.usage.output_tokens = tracing::field::Empty, gen_ai.usage.input_tokens = tracing::field::Empty, - gen_ai.input.messages = serde_json::to_string(&request["messages"]).unwrap(), + gen_ai.input.messages = serde_json::to_string(&request.messages)?, gen_ai.output.messages = tracing::field::Empty, ) } else { diff --git a/rig-core/src/providers/hyperbolic.rs b/rig-core/src/providers/hyperbolic.rs index fb746015c..3e57a7e95 100644 --- a/rig-core/src/providers/hyperbolic.rs +++ b/rig-core/src/providers/hyperbolic.rs @@ -13,8 +13,6 @@ use super::openai::{AssistantContent, send_compatible_streaming_request}; use crate::client::{self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder}; use crate::client::{BearerAuth, ProviderClient}; use crate::http_client::{self, HttpClientExt}; -use crate::json_utils::merge_inplace; -use crate::message; use crate::streaming::StreamingCompletionResponse; use crate::providers::openai; @@ -25,7 +23,6 @@ use crate::{ providers::openai::Message, }; use serde::{Deserialize, Serialize}; -use serde_json::{Value, json}; // ================================================================ // Main Hyperbolic Client @@ -249,6 +246,59 @@ pub struct Choice { pub finish_reason: String, } +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct HyperbolicCompletionRequest { + model: String, + pub messages: Vec, + #[serde(flatten, skip_serializing_if = "Option::is_none")] + temperature: Option, + #[serde(flatten, skip_serializing_if = "Option::is_none")] + pub additional_params: Option, +} + +impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest { + type Error = CompletionError; + + fn try_from((model, req): (&str, CompletionRequest)) -> Result { + if req.tool_choice.is_some() { + 
tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic"); + } + + if !req.tools.is_empty() { + tracing::warn!("WARNING: `tools` not supported on Hyperbolic"); + } + + let mut full_history: Vec = match &req.preamble { + Some(preamble) => vec![Message::system(preamble)], + None => vec![], + }; + + if let Some(docs) = req.normalized_documents() { + let docs: Vec = docs.try_into()?; + full_history.extend(docs); + } + + let chat_history: Vec = req + .chat_history + .clone() + .into_iter() + .map(|message| message.try_into()) + .collect::>, _>>()? + .into_iter() + .flatten() + .collect(); + + full_history.extend(chat_history); + + Ok(Self { + model: model.to_string(), + messages: full_history, + temperature: req.temperature, + additional_params: req.additional_params, + }) + } +} + #[derive(Clone)] pub struct CompletionModel { client: Client, @@ -270,51 +320,6 @@ impl CompletionModel { model: model.into(), } } - - pub(crate) fn create_completion_request( - &self, - completion_request: CompletionRequest, - ) -> Result { - if completion_request.tool_choice.is_some() { - tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic"); - } - // Build up the order of messages (context, chat_history, prompt) - let mut partial_history = vec![]; - if let Some(docs) = completion_request.normalized_documents() { - partial_history.push(docs); - } - partial_history.extend(completion_request.chat_history); - - // Initialize full history with preamble (or empty if non-existent) - let mut full_history: Vec = completion_request - .preamble - .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]); - - // Convert and extend the rest of the history - full_history.extend( - partial_history - .into_iter() - .map(message::Message::try_into) - .collect::>, _>>()? 
- .into_iter() - .flatten() - .collect::>(), - ); - - let request = json!({ - "model": self.model, - "messages": full_history, - "temperature": completion_request.temperature, - }); - - let request = if let Some(params) = completion_request.additional_params { - json_utils::merge(request, params) - } else { - request - }; - - Ok(request) - } } impl completion::CompletionModel for CompletionModel @@ -336,7 +341,8 @@ where completion_request: CompletionRequest, ) -> Result, CompletionError> { let preamble = completion_request.preamble.clone(); - let request = self.create_completion_request(completion_request)?; + let request = + HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?; let body = serde_json::to_vec(&request)?; let span = if tracing::Span::current().is_disabled() { @@ -351,7 +357,7 @@ where gen_ai.response.model = tracing::field::Empty, gen_ai.usage.output_tokens = tracing::field::Empty, gen_ai.usage.input_tokens = tracing::field::Empty, - gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.input.messages = serde_json::to_string(&request.messages)?, gen_ai.output.messages = tracing::field::Empty, ) } else { @@ -403,7 +409,8 @@ where completion_request: CompletionRequest, ) -> Result, CompletionError> { let preamble = completion_request.preamble.clone(); - let mut request = self.create_completion_request(completion_request)?; + let mut request = + HyperbolicCompletionRequest::try_from((self.model.as_ref(), completion_request))?; let span = if tracing::Span::current().is_disabled() { info_span!( @@ -417,18 +424,20 @@ where gen_ai.response.model = tracing::field::Empty, gen_ai.usage.output_tokens = tracing::field::Empty, gen_ai.usage.input_tokens = tracing::field::Empty, - gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(), + gen_ai.input.messages = serde_json::to_string(&request.messages)?, gen_ai.output.messages = tracing::field::Empty, ) } else { tracing::Span::current() }; - merge_inplace( - &mut request, - json!({"stream": true, "stream_options": {"include_usage": true}}), + let params = json_utils::merge( + request.additional_params.unwrap_or(serde_json::json!({})), + serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }), ); + request.additional_params = Some(params); + let body = serde_json::to_vec(&request)?; let req = self diff --git a/rig-core/src/providers/mira.rs b/rig-core/src/providers/mira.rs index 137126797..62e522bac 100644 --- a/rig-core/src/providers/mira.rs +++ b/rig-core/src/providers/mira.rs @@ -12,7 +12,6 @@ use crate::client::{ ProviderClient, }; use crate::http_client::{self, HttpClientExt}; -use crate::json_utils::merge; use crate::message::{Document, DocumentSourceKind}; use crate::providers::openai; use crate::providers::openai::send_compatible_streaming_request; @@ -23,7 +22,6 @@ use crate::{ message::{self, AssistantContent, Message, UserContent}, }; use serde::{Deserialize, Serialize}; -use serde_json::{Value, json}; use std::string::FromUtf8Error; use thiserror::Error; use tracing::{self, Instrument, info_span}; @@ -206,48 +204,31 @@ impl ProviderClient for Client { } } -#[derive(Clone)] -pub struct CompletionModel { - client: Client, - /// Name of the model - pub model: String, +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct MiraCompletionRequest { + model: String, + pub messages: Vec, + #[serde(flatten, skip_serializing_if = "Option::is_none")] + temperature: Option, + #[serde(flatten, 
skip_serializing_if = "Option::is_none")] + max_tokens: Option, + pub stream: bool, } -impl CompletionModel { - pub fn new(client: Client, model: impl Into) -> Self { - Self { - client, - model: model.into(), - } - } - - pub fn with_model(client: Client, model: &str) -> Self { - Self { - client, - model: model.into(), - } - } - - fn create_completion_request( - &self, - completion_request: CompletionRequest, - ) -> Result { - if completion_request.tool_choice.is_some() { - tracing::warn!("WARNING: `tool_choice` not supported on Mira AI"); - } +impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest { + type Error = CompletionError; + fn try_from((model, req): (&str, CompletionRequest)) -> Result { let mut messages = Vec::new(); - // Add preamble as user message if available - if let Some(preamble) = &completion_request.preamble { - messages.push(serde_json::json!({ - "role": "user", - "content": preamble.to_string() - })); + if let Some(content) = &req.preamble { + messages.push(RawMessage { + role: "user".to_string(), + content: content.to_string(), + }); } - // Add docs - if let Some(Message::User { content }) = completion_request.normalized_documents() { + if let Some(Message::User { content }) = req.normalized_documents() { let text = content .into_iter() .filter_map(|doc| match doc { @@ -263,14 +244,13 @@ impl CompletionModel { .collect::>() .join("\n"); - messages.push(serde_json::json!({ - "role": "user", - "content": text - })); + messages.push(RawMessage { + role: "user".to_string(), + content: text, + }); } - // Add chat history - for msg in completion_request.chat_history { + for msg in req.chat_history { let (role, content) = match msg { Message::User { content } => { let text = content @@ -295,21 +275,35 @@ impl CompletionModel { ("assistant", text) } }; - messages.push(serde_json::json!({ - "role": role, - "content": content - })); + messages.push(RawMessage { + role: role.to_string(), + content, + }); } - let request = serde_json::json!({ - "model": self.model, - "messages": messages, - "temperature": completion_request.temperature.map(|t| t as f32).unwrap_or(0.7), - "max_tokens": completion_request.max_tokens.map(|t| t as u32).unwrap_or(100), - "stream": false - }); + Ok(Self { + model: model.to_string(), + messages, + temperature: req.temperature, + max_tokens: req.max_tokens, + stream: false, + }) + } +} + +#[derive(Clone)] +pub struct CompletionModel { + client: Client, + /// Name of the model + pub model: String, +} - Ok(request) +impl CompletionModel { + pub fn new(client: Client, model: impl Into) -> Self { + Self { + client, + model: model.into(), + } } } @@ -333,14 +327,21 @@ where ) -> Result, CompletionError> { if !completion_request.tools.is_empty() { tracing::warn!(target: "rig::completions", - "Tool calls are not supported by the Mira provider. {len} tools will be ignored.", + "Tool calls are not supported by Mira AI. 
{len} tools will be ignored.",
                len = completion_request.tools.len()
            );
        }

-        let preamble = completion_request.preamble.clone();
+        if completion_request.tool_choice.is_some() {
+            tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
+        }
+
+        if completion_request.additional_params.is_some() {
+            tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
+        }

-        let request = self.create_completion_request(completion_request)?;
+        let preamble = completion_request.preamble.clone();
+        let request = MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
@@ -354,7 +355,7 @@ where
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
@@ -426,8 +427,24 @@ where
        &self,
        completion_request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
+        if !completion_request.tools.is_empty() {
+            tracing::warn!(target: "rig::completions",
+                "Tool calls are not supported by Mira AI. {len} tools will be ignored.",
+                len = completion_request.tools.len()
+            );
+        }
+
+        if completion_request.tool_choice.is_some() {
+            tracing::warn!("WARNING: `tool_choice` not supported on Mira AI");
+        }
+
+        if completion_request.additional_params.is_some() {
+            tracing::warn!("WARNING: Additional parameters not supported on Mira AI");
+        }
        let preamble = completion_request.preamble.clone();
-        let mut request = self.create_completion_request(completion_request)?;
+        let mut request =
+            MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
+
+        request.stream = true;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
@@ -441,13 +458,12 @@ where
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
-        request = merge(request, json!({"stream": true}));

        let body = serde_json::to_vec(&request)?;

        let req = self
diff --git a/rig-core/src/providers/mistral/completion.rs b/rig-core/src/providers/mistral/completion.rs
index 0b9638e31..3e604ea04 100644
--- a/rig-core/src/providers/mistral/completion.rs
+++ b/rig-core/src/providers/mistral/completion.rs
@@ -1,6 +1,5 @@
 use async_stream::stream;
 use serde::{Deserialize, Serialize};
-use serde_json::{Value, json};
 use std::{convert::Infallible, str::FromStr};
 use tracing::{Instrument, info_span};

@@ -285,6 +284,67 @@ impl TryFrom<message::ToolChoice> for ToolChoice {
    }
}

+#[derive(Debug, Serialize, Deserialize)]
+pub(super) struct MistralCompletionRequest {
+    model: String,
+    pub messages: Vec<Message>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    tools: Vec<ToolDefinition>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
+    #[serde(flatten, skip_serializing_if = "Option::is_none")]
+    pub additional_params: Option<serde_json::Value>,
+}
+
+impl TryFrom<(&str, CompletionRequest)> for MistralCompletionRequest {
+    type Error = CompletionError;
+
+    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
+        let mut full_history: Vec<Message> = match &req.preamble {
+            Some(preamble) => vec![Message::system(preamble.clone())],
+            None => vec![],
+        };
+        if let Some(docs) = req.normalized_documents() {
+            let docs: Vec<Message> = docs.try_into()?;
+            full_history.extend(docs);
+        }
+
+        let chat_history: Vec<Message> = req
+            .chat_history
+            .clone()
+            .into_iter()
+            .map(|message| message.try_into())
+            .collect::<Result<Vec<Vec<Message>>, _>>()?
+            .into_iter()
+            .flatten()
+            .collect();
+
+        full_history.extend(chat_history);
+
+        let tool_choice = req
+            .tool_choice
+            .clone()
+            .map(crate::providers::openai::completion::ToolChoice::try_from)
+            .transpose()?;
+
+        Ok(Self {
+            model: model.to_string(),
+            messages: full_history,
+            temperature: req.temperature,
+            tools: req
+                .tools
+                .clone()
+                .into_iter()
+                .map(ToolDefinition::from)
+                .collect::<Vec<_>>(),
+            tool_choice,
+            additional_params: req.additional_params,
+        })
+    }
+}
+
 impl CompletionModel {
    pub fn new(client: Client, model: impl Into<String>) -> Self {
        Self {
@@ -299,72 +359,6 @@ impl CompletionModel {
            model: model.into(),
        }
    }
-
-    pub(crate) fn create_completion_request(
-        &self,
-        completion_request: CompletionRequest,
-    ) -> Result<Value, CompletionError> {
-        let mut partial_history = vec![];
-        if let Some(docs) = completion_request.normalized_documents() {
-            partial_history.push(docs);
-        }
-
-        partial_history.extend(completion_request.chat_history);
-
-        let mut full_history: Vec<Message> = match &completion_request.preamble {
-            Some(preamble) => vec![Message::system(preamble.clone())],
-            None => vec![],
-        };
-
-        full_history.extend(
-            partial_history
-                .into_iter()
-                .map(message::Message::try_into)
-                .collect::<Result<Vec<Vec<Message>>, _>>()?
-                .into_iter()
-                .flatten()
-                .collect::<Vec<Message>>(),
-        );
-
-        let tool_choice = completion_request
-            .tool_choice
-            .map(ToolChoice::try_from)
-            .transpose()?;
-
-        let request = if completion_request.tools.is_empty() {
-            json!({
-                "model": self.model,
-                "messages": full_history,
-            })
-        } else {
-            json!({
-                "model": self.model,
-                "messages": full_history,
-                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
-                "tool_choice": tool_choice,
-            })
-        };
-
-        let request = if let Some(temperature) = completion_request.temperature {
-            json_utils::merge(
-                request,
-                json!({
-                    "temperature": temperature,
-                }),
-            )
-        } else {
-            request
-        };
-
-        let request = if let Some(params) = completion_request.additional_params {
-            json_utils::merge(request, params)
-        } else {
-            request
-        };
-
-        Ok(request)
-    }
 }

 #[derive(Debug, Deserialize, Clone, Serialize)]
@@ -513,8 +507,8 @@ where
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
-        let request = self.create_completion_request(completion_request)?;
-        let body = serde_json::to_vec(&request)?;
+        let request =
+            MistralCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
@@ -528,13 +522,15 @@ where
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

+        let body = serde_json::to_vec(&request)?;
+
        let request = self
            .client
            .post("v1/chat/completions")?
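The typed request structs above rely entirely on serde attributes to reproduce what the old `json!`-plus-merge builders emitted. A minimal self-contained sketch of that behavior, using stand-in types rather than the real `MistralCompletionRequest` (the model name is illustrative): `skip_serializing_if` drops unset fields, and the flattened `additional_params` splice in at the top level of the body. Note that `flatten` only works on map-like values, which is why the scalar `temperature`/`tool_choice` fields above use a plain `skip_serializing_if` instead.

use serde::Serialize;
use serde_json::json;

// Stand-in mirroring the provider request structs in this diff.
#[derive(Serialize)]
struct RequestSketch {
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<String>,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}

fn main() {
    let req = RequestSketch {
        model: "mistral-small-latest".into(),
        temperature: None,
        tools: vec![],
        additional_params: Some(json!({ "top_p": 0.9 })),
    };
    // None/empty fields vanish; flattened params land at the top level:
    // {"model":"mistral-small-latest","top_p":0.9}
    println!("{}", serde_json::to_string(&req).unwrap());
}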
diff --git a/rig-core/src/providers/moonshot.rs b/rig-core/src/providers/moonshot.rs
index f98366692..7f7b4bbed 100644
--- a/rig-core/src/providers/moonshot.rs
+++ b/rig-core/src/providers/moonshot.rs
@@ -13,7 +13,6 @@ use crate::client::{
    ProviderClient,
 };
 use crate::http_client::HttpClientExt;
-use crate::json_utils::merge;
 use crate::providers::openai::send_compatible_streaming_request;
 use crate::streaming::StreamingCompletionResponse;
 use crate::{
@@ -23,7 +22,6 @@ use crate::{
 };
 use crate::{http_client, message};
 use serde::{Deserialize, Serialize};
-use serde_json::{Value, json};
 use tracing::{Instrument, info_span};

// ================================================================
@@ -115,44 +113,38 @@ enum ApiResponse {

pub const MOONSHOT_CHAT: &str = "moonshot-v1-128k";

-#[derive(Clone)]
-pub struct CompletionModel {
-    client: Client,
-    pub model: String,
+#[derive(Debug, Serialize, Deserialize)]
+pub(super) struct MoonshotCompletionRequest {
+    model: String,
+    pub messages: Vec<openai::Message>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    tools: Vec<openai::ToolDefinition>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    max_tokens: Option<u64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tool_choice: Option<openai::ToolChoice>,
+    #[serde(flatten, skip_serializing_if = "Option::is_none")]
+    pub additional_params: Option<serde_json::Value>,
 }

-impl CompletionModel {
-    pub fn new(client: Client, model: impl Into<String>) -> Self {
-        Self {
-            client,
-            model: model.into(),
-        }
-    }
-
-    pub fn with_model(client: Client, model: &str) -> Self {
-        Self {
-            client,
-            model: model.into(),
-        }
-    }
+impl TryFrom<(&str, CompletionRequest)> for MoonshotCompletionRequest {
+    type Error = CompletionError;

-    fn create_completion_request(
-        &self,
-        completion_request: CompletionRequest,
-    ) -> Result<Value, CompletionError> {
-        // Build up the order of messages (context, chat_history)
+    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
+        // Build up the order of messages (context, chat_history, prompt)
        let mut partial_history = vec![];
-        if let Some(docs) = completion_request.normalized_documents() {
+        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
-        partial_history.extend(completion_request.chat_history);
+        partial_history.extend(req.chat_history);

-        // Initialize full history with preamble (or empty if non-existent)
-        let mut full_history: Vec<openai::Message> = completion_request
-            .preamble
-            .map_or_else(Vec::new, |preamble| {
-                vec![openai::Message::system(&preamble)]
-            });
+        // Add preamble to chat history (if available)
+        let mut full_history: Vec<openai::Message> = match &req.preamble {
+            Some(preamble) => vec![openai::Message::system(preamble)],
+            None => vec![],
+        };

        // Convert and extend the rest of the history
        full_history.extend(
@@ -165,36 +157,41 @@ impl CompletionModel {
                .collect::<Vec<_>>(),
        );

-        let tool_choice = completion_request
+        let tool_choice = req
            .tool_choice
-            .map(ToolChoice::try_from)
+            .clone()
+            .map(crate::providers::openai::ToolChoice::try_from)
            .transpose()?;

-        let request = if completion_request.tools.is_empty() {
-            json!({
-                "model": self.model,
-                "messages": full_history,
-                "temperature": completion_request.temperature,
-                "max_tokens": completion_request.max_tokens,
-            })
-        } else {
-            json!({
-                "model": self.model,
-                "messages": full_history,
-                "temperature": completion_request.temperature,
-                "max_tokens": completion_request.max_tokens,
-                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
-                "tool_choice": tool_choice,
-            })
-        };
+        Ok(Self {
+            model: model.to_string(),
+            messages: full_history,
+            temperature: req.temperature,
+            max_tokens: req.max_tokens,
+            tools: req
+                .tools
+                .clone()
+                .into_iter()
+                .map(openai::ToolDefinition::from)
+                .collect::<Vec<_>>(),
+            tool_choice,
+            additional_params: req.additional_params,
+        })
+    }
+}

-        let request = if let Some(params) = completion_request.additional_params {
-            json_utils::merge(request, params)
-        } else {
-            request
-        };
+#[derive(Clone)]
+pub struct CompletionModel {
+    client: Client,
+    pub model: String,
+}

-        Ok(request)
+impl CompletionModel {
+    pub fn new(client: Client, model: impl Into<String>) -> Self {
+        Self {
+            client,
+            model: model.into(),
+        }
    }
 }

@@ -217,9 +214,10 @@ where
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
-        let request = self.create_completion_request(completion_request)?;
+        let request =
+            MoonshotCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

-        println!(
+        tracing::trace!(
            "Moonshot API input: {request}",
            request = serde_json::to_string_pretty(&request).unwrap()
        );
@@ -236,7 +234,7 @@ where
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
@@ -305,7 +303,7 @@ where
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let preamble = request.preamble.clone();
-        let mut request = self.create_completion_request(request)?;
+        let mut request = MoonshotCompletionRequest::try_from((self.model.as_ref(), request))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
@@ -319,18 +317,20 @@ where
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

-        request = merge(
-            request,
-            json!({"stream": true, "stream_options": {"include_usage": true}}),
+        let params = json_utils::merge(
+            request.additional_params.unwrap_or(serde_json::json!({})),
+            serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
        );

+        request.additional_params = Some(params);
+
        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
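Because the typed structs no longer expose a raw `serde_json::Value`, the streaming paths fold their flags into `additional_params` with `json_utils::merge`. A sketch of what a shallow object merge like rig's helper plausibly does (stand-in implementation; the real `json_utils::merge` may handle edge cases differently):

use serde_json::{Value, json};

// Assumed behavior: keys from `other` are appended to `base` when both are
// JSON objects; otherwise `base` is returned unchanged.
fn merge(base: Value, other: Value) -> Value {
    match (base, other) {
        (Value::Object(mut a), Value::Object(b)) => {
            a.extend(b);
            Value::Object(a)
        }
        (a, _) => a,
    }
}

fn main() {
    let additional_params: Option<Value> = Some(json!({ "top_p": 0.9 }));
    let params = merge(
        additional_params.unwrap_or_else(|| json!({})),
        json!({ "stream": true, "stream_options": { "include_usage": true } }),
    );
    // {"top_p":0.9,"stream":true,"stream_options":{"include_usage":true}}
    println!("{params}");
}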
diff --git a/rig-core/src/providers/ollama.rs b/rig-core/src/providers/ollama.rs
index 17ca457bb..05fdd33ab 100644
--- a/rig-core/src/providers/ollama.rs
+++ b/rig-core/src/providers/ollama.rs
@@ -43,7 +43,6 @@ use crate::client::{
 };
 use crate::completion::{GetTokenUsage, Usage};
 use crate::http_client::{self, HttpClientExt};
-use crate::json_utils::merge_inplace;
 use crate::message::DocumentSourceKind;
 use crate::streaming::RawStreamingChoice;
 use crate::{
@@ -355,87 +354,96 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse {
-    client: Client,
-    pub model: String,
+#[derive(Debug, Serialize, Deserialize)]
+pub(super) struct OllamaCompletionRequest {
+    model: String,
+    pub messages: Vec<Message>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    tools: Vec<ToolDefinition>,
+    pub stream: bool,
+    think: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    max_tokens: Option<u64>,
+    options: serde_json::Value,
 }

-impl CompletionModel {
-    pub fn new(client: Client, model: &str) -> Self {
-        Self {
-            client,
-            model: model.to_owned(),
-        }
-    }
+impl TryFrom<(&str, CompletionRequest)> for OllamaCompletionRequest {
+    type Error = CompletionError;

-    fn create_completion_request(
-        &self,
-        completion_request: CompletionRequest,
-    ) -> Result<Value, CompletionError> {
-        if completion_request.tool_choice.is_some() {
+    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
+        if req.tool_choice.is_some() {
            tracing::warn!("WARNING: `tool_choice` not supported for Ollama");
        }
-
-        // Build up the order of messages (context, chat_history)
+        // Build up the order of messages (context, chat_history, prompt)
        let mut partial_history = vec![];
-        if let Some(docs) = completion_request.normalized_documents() {
+        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
-        partial_history.extend(completion_request.chat_history);
+        partial_history.extend(req.chat_history);

-        // Initialize full history with preamble (or empty if non-existent)
-        let mut full_history: Vec<Message> = completion_request
-            .preamble
-            .map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]);
+        // Add preamble to chat history (if available)
+        let mut full_history: Vec<Message> = match &req.preamble {
+            Some(preamble) => vec![Message::system(preamble)],
+            None => vec![],
+        };

        // Convert and extend the rest of the history
        full_history.extend(
            partial_history
                .into_iter()
-                .map(|msg| msg.try_into())
+                .map(message::Message::try_into)
                .collect::<Result<Vec<Vec<Message>>, _>>()?
                .into_iter()
                .flatten()
                .collect::<Vec<Message>>(),
        );

-        let mut request_payload = json!({
-            "model": self.model,
-            "messages": full_history,
-            "stream": false,
-        });
+        let mut think = false;

-        // Convert internal prompt into a provider Message
-        let options = if let Some(mut extra) = completion_request.additional_params {
+        // TODO: Fix this up to include the full range of ollama options
+        let options = if let Some(mut extra) = req.additional_params {
            if extra.get("think").is_some() {
-                request_payload["think"] = extra["think"].take();
+                think = extra["think"].take().as_bool().ok_or_else(|| {
+                    CompletionError::RequestError("`think` must be a bool".into())
+                })?;
            }
-            json_utils::merge(
-                json!({ "temperature": completion_request.temperature }),
-                extra,
-            )
+            json_utils::merge(json!({ "temperature": req.temperature }), extra)
        } else {
-            json!({ "temperature": completion_request.temperature })
+            json!({ "temperature": req.temperature })
        };

-        request_payload["options"] = options;
-
-        if !completion_request.tools.is_empty() {
-            request_payload["tools"] = json!(
-                completion_request
-                    .tools
-                    .into_iter()
-                    .map(|tool| tool.into())
-                    .collect::<Vec<ToolDefinition>>()
-            );
-        }
+        Ok(Self {
+            model: model.to_string(),
+            messages: full_history,
+            temperature: req.temperature,
+            max_tokens: req.max_tokens,
+            stream: false,
+            think,
+            tools: req
+                .tools
+                .clone()
+                .into_iter()
+                .map(ToolDefinition::from)
+                .collect::<Vec<_>>(),
+            options,
+        })
+    }
+}

-        tracing::debug!(target: "rig", "Chat mode payload: {}", request_payload);
+#[derive(Clone)]
+pub struct CompletionModel {
+    client: Client,
+    pub model: String,
+}

-        Ok(request_payload)
+impl CompletionModel {
+    pub fn new(client: Client, model: &str) -> Self {
+        Self {
+            client,
+            model: model.to_owned(),
+        }
    }
 }

@@ -484,7 +492,7 @@ where
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
-        let request = self.create_completion_request(completion_request)?;
+        let request = OllamaCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
@@ -498,7 +506,7 @@ where
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
@@ -563,8 +571,8 @@ where
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let preamble = request.preamble.clone();

-        let mut request = self.create_completion_request(request)?;
-        merge_inplace(&mut request, json!({"stream": true}));
+        let mut request = OllamaCompletionRequest::try_from((self.model.as_ref(), request))?;
+        request.stream = true;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
@@ -578,7 +586,7 @@ where
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
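One subtlety in the Ollama conversion above: `Value::take` swaps the extracted entry for `null`, so after `think` is pulled out, a `"think": null` residue remains in `extra` and rides along into `options`. A small sketch of the extraction logic with a hypothetical helper (not a function in rig):

use serde_json::{Value, json};

// Mirrors the `think` handling above: pull a boolean flag out of the
// caller-supplied params, erroring if it is present but not a bool.
fn extract_think(extra: &mut Value) -> Result<bool, String> {
    if extra.get("think").is_some() {
        extra["think"]
            .take()
            .as_bool()
            .ok_or_else(|| "`think` must be a bool".to_string())
    } else {
        Ok(false)
    }
}

fn main() {
    let mut extra = json!({ "think": true, "num_ctx": 4096 });
    assert_eq!(extract_think(&mut extra), Ok(true));
    // `take()` leaves `null` behind, which the merge then carries into `options`:
    assert_eq!(extra, json!({ "think": null, "num_ctx": 4096 }));
}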
diff --git a/rig-core/src/providers/openrouter/completion.rs b/rig-core/src/providers/openrouter/completion.rs
index ab1ed2c7e..e92c437ad 100644
--- a/rig-core/src/providers/openrouter/completion.rs
+++ b/rig-core/src/providers/openrouter/completion.rs
@@ -14,7 +14,6 @@ use crate::{
 };
 use bytes::Bytes;
 use serde::{Deserialize, Serialize};
-use serde_json::{Value, json};
 use tracing::{Instrument, info_span};

// ================================================================
@@ -307,46 +306,36 @@ pub enum ToolChoiceFunctionKind {
    Function { name: String },
 }

-#[derive(Clone)]
-pub struct CompletionModel {
-    pub(crate) client: Client,
-    pub model: String,
+#[derive(Debug, Serialize, Deserialize)]
+pub(super) struct OpenrouterCompletionRequest {
+    model: String,
+    pub messages: Vec<Message>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    tools: Vec<crate::providers::openai::completion::ToolDefinition>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
+    #[serde(flatten, skip_serializing_if = "Option::is_none")]
+    pub additional_params: Option<serde_json::Value>,
 }

-impl CompletionModel {
-    pub fn new(client: Client, model: impl Into<String>) -> Self {
-        Self {
-            client,
-            model: model.into(),
-        }
-    }
-
-    pub fn with_model(client: Client, model: &str) -> Self {
-        Self {
-            client,
-            model: model.into(),
-        }
-    }
+impl TryFrom<(&str, CompletionRequest)> for OpenrouterCompletionRequest {
+    type Error = CompletionError;

-    pub(crate) fn create_completion_request(
-        &self,
-        completion_request: CompletionRequest,
-    ) -> Result<Value, CompletionError> {
-        // Add preamble to chat history (if available)
-        let mut full_history: Vec<Message> = match &completion_request.preamble {
+    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
+        let mut full_history: Vec<Message> = match &req.preamble {
            Some(preamble) => vec![Message::system(preamble)],
            None => vec![],
        };
-
-        // Gather docs
-        if let Some(docs) = completion_request.normalized_documents() {
+        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<Message> = docs.try_into()?;
            full_history.extend(docs);
        }

-        // Convert existing chat history
-        let chat_history: Vec<Message> = completion_request
+        let chat_history: Vec<Message> = req
            .chat_history
+            .clone()
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<Message>>, _>>()?
@@ -354,41 +343,42 @@ impl CompletionModel {
            .flatten()
            .collect();

-        // Combine all messages into a single history
        full_history.extend(chat_history);

-        let tool_choice = completion_request
+        let tool_choice = req
            .tool_choice
-            .map(ToolChoice::try_from)
+            .clone()
+            .map(crate::providers::openai::completion::ToolChoice::try_from)
            .transpose()?;

-        let mut request = json!({
-            "model": self.model,
-            "messages": full_history,
-        });
+        Ok(Self {
+            model: model.to_string(),
+            messages: full_history,
+            temperature: req.temperature,
+            tools: req
+                .tools
+                .clone()
+                .into_iter()
+                .map(crate::providers::openai::completion::ToolDefinition::from)
+                .collect::<Vec<_>>(),
+            tool_choice,
+            additional_params: req.additional_params,
+        })
+    }
+}

-        if let Some(temperature) = completion_request.temperature {
-            request["temperature"] = json!(temperature);
-        }
+#[derive(Clone)]
+pub struct CompletionModel {
+    pub(crate) client: Client,
+    pub model: String,
+}

-        if !completion_request.tools.is_empty() {
-            request["tools"] = json!(
-                completion_request
-                    .tools
-                    .into_iter()
-                    .map(crate::providers::openai::completion::ToolDefinition::from)
-                    .collect::<Vec<_>>()
-            );
-            request["tool_choice"] = json!(tool_choice);
+impl CompletionModel {
+    pub fn new(client: Client, model: impl Into<String>) -> Self {
+        Self {
+            client,
+            model: model.into(),
        }
-
-        let request = if let Some(params) = completion_request.additional_params {
-            json_utils::merge(request, params)
-        } else {
-            request
-        };
-
-        Ok(request)
    }
 }

@@ -411,7 +401,8 @@ where
        completion_request: CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
-        let request = self.create_completion_request(completion_request)?;
+        let request =
+            OpenrouterCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
@@ -424,7 +415,7 @@ where
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
diff --git a/rig-core/src/providers/openrouter/streaming.rs b/rig-core/src/providers/openrouter/streaming.rs
index 0940a07b7..5d2f27089 100644
--- a/rig-core/src/providers/openrouter/streaming.rs
+++ b/rig-core/src/providers/openrouter/streaming.rs
@@ -4,7 +4,7 @@ use async_stream::stream;
 use futures::StreamExt;
 use http::Request;
 use serde::{Deserialize, Serialize};
-use serde_json::{Value, json};
+use serde_json::Value;
 use tracing::info_span;
 use tracing_futures::Instrument;

@@ -14,6 +14,7 @@ use crate::http_client::HttpClientExt;
 use crate::http_client::sse::{Event, GenericEventSource};
 use crate::json_utils;
 use crate::message::{ToolCall, ToolFunction};
+use crate::providers::openrouter::OpenrouterCompletionRequest;
 use crate::streaming;

#[derive(Clone, Serialize, Deserialize, Debug)]
@@ -114,9 +115,15 @@ where
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();

-        let request = self.create_completion_request(completion_request)?;
+        let mut request =
+            OpenrouterCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

-        let request = json_utils::merge(request, json!({"stream": true}));
+        let params = json_utils::merge(
+            request.additional_params.unwrap_or(serde_json::json!({})),
+            serde_json::json!({"stream": true }),
+        );
+
+        request.additional_params = Some(params);

        let body = serde_json::to_vec(&request)?;

@@ -139,7 +146,7 @@ where
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
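Because `additional_params` is flattened, whatever the caller supplies is emitted alongside the typed fields; a key that duplicates one of them (for example `model`) would likely be written twice in the output JSON, so overlapping keys are worth avoiding. A sketch of the flattened stream flag (illustrative field set and model id, not the exact `OpenrouterCompletionRequest`):

use serde::Serialize;
use serde_json::json;

// Illustrative stand-in for the flattened-params pattern used above.
#[derive(Serialize)]
struct FlattenSketch {
    model: String,
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    additional_params: Option<serde_json::Value>,
}

fn main() {
    let req = FlattenSketch {
        model: "openrouter/auto".into(),
        additional_params: Some(json!({ "stream": true })),
    };
    // The merged stream flag lands at the top level of the body:
    // {"model":"openrouter/auto","stream":true}
    println!("{}", serde_json::to_string(&req).unwrap());
}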
diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs
index 7396ef4b2..480c34db5 100644
--- a/rig-core/src/providers/perplexity.rs
+++ b/rig-core/src/providers/perplexity.rs
@@ -10,7 +10,6 @@
//! ```
 use crate::client::BearerAuth;
 use crate::completion::CompletionRequest;
-use crate::json_utils::merge;
 use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
 use crate::streaming::StreamingCompletionResponse;
@@ -21,11 +20,9 @@ use crate::{
    },
    completion::{self, CompletionError, MessageError, message},
    http_client::{self, HttpClientExt},
-    json_utils,
 };
 use bytes::Bytes;
 use serde::{Deserialize, Serialize};
-use serde_json::{Value, json};
 use tracing::{Instrument, info_span};

// ================================================================
@@ -200,52 +197,36 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse {
-    client: Client,
-    pub model: String,
+#[derive(Debug, Serialize, Deserialize)]
+pub(super) struct PerplexityCompletionRequest {
+    model: String,
+    pub messages: Vec<Message>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_tokens: Option<u64>,
+    #[serde(flatten, skip_serializing_if = "Option::is_none")]
+    additional_params: Option<serde_json::Value>,
+    pub stream: bool,
 }

-impl CompletionModel {
-    pub fn new(client: Client, model: impl Into<String>) -> Self {
-        Self {
-            client,
-            model: model.into(),
-        }
-    }
-
-    pub fn with_model(client: Client, model: &str) -> Self {
-        Self {
-            client,
-            model: model.into(),
-        }
-    }
-
-    fn create_completion_request(
-        &self,
-        completion_request: CompletionRequest,
-    ) -> Result<Value, CompletionError> {
-        if completion_request.tool_choice.is_some() {
-            tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
-        }
+impl TryFrom<(&str, CompletionRequest)> for PerplexityCompletionRequest {
+    type Error = CompletionError;

-        // Build up the order of messages (context, chat_history, prompt)
+    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
        let mut partial_history = vec![];
-        if let Some(docs) = completion_request.normalized_documents() {
+        if let Some(docs) = req.normalized_documents() {
            partial_history.push(docs);
        }
-        partial_history.extend(completion_request.chat_history);
+        partial_history.extend(req.chat_history);

        // Initialize full history with preamble (or empty if non-existent)
-        let mut full_history: Vec<Message> =
-            completion_request
-                .preamble
-                .map_or_else(Vec::new, |preamble| {
-                    vec![Message {
-                        role: Role::System,
-                        content: preamble,
-                    }]
-                });
+        let mut full_history: Vec<Message> = req.preamble.map_or_else(Vec::new, |preamble| {
+            vec![Message {
+                role: Role::System,
+                content: preamble,
+            }]
+        });

        // Convert and extend the rest of the history
        full_history.extend(
@@ -255,20 +236,29 @@ impl CompletionModel {
                .collect::<Result<Vec<_>, _>>()?,
        );

-        // Compose request
-        let request = json!({
-            "model": self.model,
-            "messages": full_history,
-            "temperature": completion_request.temperature,
-        });
+        Ok(Self {
+            model: model.to_string(),
+            messages: full_history,
+            temperature: req.temperature,
+            max_tokens: req.max_tokens,
+            additional_params: req.additional_params,
+            stream: false,
+        })
+    }
+}

-        let request = if let Some(ref params) = completion_request.additional_params {
-            json_utils::merge(request, params.clone())
-        } else {
-            request
-        };
+#[derive(Clone)]
+pub struct CompletionModel {
+    client: Client,
+    pub model: String,
+}

-        Ok(request)
+impl CompletionModel {
+    pub fn new(client: Client, model: impl Into<String>) -> Self {
+        Self {
+            client,
+            model: model.into(),
+        }
    }
 }

@@ -350,8 +340,16 @@ where
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+        if completion_request.tool_choice.is_some() {
+            tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
+        }
+
+        if !completion_request.tools.is_empty() {
+            tracing::warn!("WARNING: `tools` not supported on Perplexity");
+        }
        let preamble = completion_request.preamble.clone();
-        let request = self.create_completion_request(completion_request)?;
+        let request =
+            PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
@@ -365,7 +363,7 @@ where
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
@@ -421,10 +419,20 @@ where
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
+        if completion_request.tool_choice.is_some() {
+            tracing::warn!("WARNING: `tool_choice` not supported on Perplexity");
+        }
+
+        if !completion_request.tools.is_empty() {
+            tracing::warn!("WARNING: `tools` not supported on Perplexity");
+        }
+
        let preamble = completion_request.preamble.clone();
-        let mut request = self.create_completion_request(completion_request)?;
-        request = merge(request, json!({"stream": true}));
+        let mut request =
+            PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
+        request.stream = true;
+
        let body = serde_json::to_vec(&request)?;

        let req = self
@@ -446,7 +454,7 @@ where
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
-                gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
+                gen_ai.input.messages = serde_json::to_string(&request.messages)?,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
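Perplexity differs from the merge-based providers above: `stream` is a real typed field, so the streaming path just flips a bool instead of merging a flag into `additional_params`. A sketch of the serialized result (simplified stand-in struct; "sonar" is an illustrative model name):

use serde::Serialize;

// Simplified stand-in for the Perplexity request shape above.
#[derive(Serialize)]
struct PerplexitySketch {
    model: String,
    stream: bool,
}

fn main() {
    let mut req = PerplexitySketch { model: "sonar".into(), stream: false };
    req.stream = true; // the streaming path only flips this field
    // {"model":"sonar","stream":true}
    println!("{}", serde_json::to_string(&req).unwrap());
}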
diff --git a/rig-core/src/providers/together/completion.rs b/rig-core/src/providers/together/completion.rs
index 46b2f74dc..4df314c6e 100644
--- a/rig-core/src/providers/together/completion.rs
+++ b/rig-core/src/providers/together/completion.rs
@@ -6,7 +6,6 @@
 use crate::{
    completion::{self, CompletionError},
    http_client::HttpClientExt,
-    json_utils,
    providers::openai,
 };

@@ -15,7 +14,6 @@
 use crate::completion::CompletionRequest;
 use crate::streaming::StreamingCompletionResponse;
 use bytes::Bytes;
 use serde::{Deserialize, Serialize};
-use serde_json::json;
 use tracing::{Instrument, info_span};

// ================================================================
@@ -131,39 +129,34 @@ pub const WIZARDLM_13B_V1_2: &str = "WizardLM/WizardLM-13B-V1.2";
// Rig Implementation Types
// =================================================================

-#[derive(Clone)]
-pub struct CompletionModel {
-    pub(crate) client: Client,
-    pub model: String,
+#[derive(Debug, Serialize, Deserialize)]
+pub(super) struct TogetherAICompletionRequest {
+    model: String,
+    pub messages: Vec<openai::Message>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    tools: Vec<openai::ToolDefinition>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tool_choice: Option<ToolChoice>,
+    #[serde(flatten, skip_serializing_if = "Option::is_none")]
+    pub additional_params: Option<serde_json::Value>,
 }

-impl CompletionModel {
-    pub fn new(client: Client, model: impl Into<String>) -> Self {
-        Self {
-            client,
-            model: model.into(),
-        }
-    }
+impl TryFrom<(&str, CompletionRequest)> for TogetherAICompletionRequest {
+    type Error = CompletionError;

-    pub fn with_model(client: Client, model: &str) -> Self {
-        Self {
-            client,
-            model: model.into(),
-        }
-    }
-    pub(crate) fn create_completion_request(
-        &self,
-        completion_request: completion::CompletionRequest,
-    ) -> Result<Value, CompletionError> {
-        let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
+    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
+        let mut full_history: Vec<openai::Message> = match &req.preamble {
            Some(preamble) => vec![openai::Message::system(preamble)],
            None => vec![],
        };
-        if let Some(docs) = completion_request.normalized_documents() {
+        if let Some(docs) = req.normalized_documents() {
            let docs: Vec<openai::Message> = docs.try_into()?;
            full_history.extend(docs);
        }
-        let chat_history: Vec<openai::Message> = completion_request
+
+        let chat_history: Vec<openai::Message> = req
            .chat_history
            .into_iter()
            .map(|message| message.try_into())
@@ -174,32 +167,40 @@ impl CompletionModel {

        full_history.extend(chat_history);

-        let tool_choice = completion_request
+        let tool_choice = req
            .tool_choice
+            .clone()
            .map(ToolChoice::try_from)
            .transpose()?;

-        let mut request = if completion_request.tools.is_empty() {
-            json!({
-                "model": self.model,
-                "messages": full_history,
-                "temperature": completion_request.temperature,
-            })
-        } else {
-            json!({
-                "model": self.model,
-                "messages": full_history,
-                "temperature": completion_request.temperature,
-                "tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
-                "tool_choice": tool_choice,
-            })
-        };
-        request = if let Some(params) = completion_request.additional_params {
-            json_utils::merge(request, params)
-        } else {
-            request
-        };
-        Ok(request)
+        Ok(Self {
+            model: model.to_string(),
+            messages: full_history,
+            temperature: req.temperature,
+            tools: req
+                .tools
+                .clone()
+                .into_iter()
+                .map(crate::providers::openai::completion::ToolDefinition::from)
+                .collect::<Vec<_>>(),
+            tool_choice,
+            additional_params: req.additional_params,
+        })
+    }
+}
+
+#[derive(Clone)]
+pub struct CompletionModel {
+    pub(crate) client: Client,
+    pub model: String,
+}
+
+impl CompletionModel {
+    pub fn new(client: Client, model: impl Into<String>) -> Self {
+        Self {
+            client,
+            model: model.into(),
+        }
    }
 }

@@ -222,9 +223,11 @@ where
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
-        let request = self.create_completion_request(completion_request)?;
-        let messages_as_json_string =
-            serde_json::to_string(request.get("messages").unwrap()).unwrap();
+        let request =
+            TogetherAICompletionRequest::try_from((self.model.as_str(), completion_request))?;
+        let messages_as_json_string = serde_json::to_string(&request.messages)?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
@@ -232,7 +235,7 @@ where
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.provider.name = "together",
-                gen_ai.request.model = self.model,
+                gen_ai.request.model = self.model.to_string(),
                gen_ai.system_instructions = preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
diff --git a/rig-core/src/providers/together/streaming.rs b/rig-core/src/providers/together/streaming.rs
index 65a3005d6..0a79912d7 100644
--- a/rig-core/src/providers/together/streaming.rs
+++ b/rig-core/src/providers/together/streaming.rs
@@ -1,14 +1,11 @@
-use serde_json::json;
-
 use super::completion::CompletionModel;
+use crate::completion::{CompletionError, CompletionRequest};
 use crate::http_client::HttpClientExt;
+use crate::json_utils;
 use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
+use crate::providers::together::completion::TogetherAICompletionRequest;
 use crate::streaming::StreamingCompletionResponse;
-use crate::{
-    completion::{CompletionError, CompletionRequest},
-    json_utils::merge,
-};

 use tracing::{Instrument, info_span};

@@ -22,9 +19,17 @@ where
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();

-        let mut request = self.create_completion_request(completion_request)?;
+        let mut request =
+            TogetherAICompletionRequest::try_from((self.model.as_str(), completion_request))?;
+
+        let params = json_utils::merge(
+            request.additional_params.unwrap_or(serde_json::json!({})),
+            serde_json::json!({"stream_tokens": true }),
+        );

-        request = merge(request, json!({"stream_tokens": true}));
+        request.additional_params = Some(params);

        let body = serde_json::to_vec(&request)?;

@@ -41,13 +46,13 @@ where
            "chat_streaming",
            gen_ai.operation.name = "chat_streaming",
            gen_ai.provider.name = "together",
-            gen_ai.request.model = self.model,
+            gen_ai.request.model = self.model.to_string(),
            gen_ai.system_instructions = preamble,
            gen_ai.response.id = tracing::field::Empty,
            gen_ai.response.model = tracing::field::Empty,
            gen_ai.usage.output_tokens = tracing::field::Empty,
            gen_ai.usage.input_tokens = tracing::field::Empty,
-            gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
+            gen_ai.input.messages = serde_json::to_string(&request.messages)?,
            gen_ai.output.messages = tracing::field::Empty,
        )
    } else {
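Across these OpenAI-compatible providers, `#[serde(skip_serializing_if = "Vec::is_empty")]` on `tools` preserves the old behavior where the `tools` and `tool_choice` keys were only attached when tools were actually supplied. A sketch with stand-in types (the model id is one of the constants from this file):

use serde::Serialize;

// Stand-in: empty tools disappear from the wire format entirely.
#[derive(Serialize)]
struct ToolsSketch {
    model: String,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
}

fn main() {
    let req = ToolsSketch {
        model: "grok-4-0709".into(),
        tools: vec![],
        tool_choice: None,
    };
    // {"model":"grok-4-0709"}
    println!("{}", serde_json::to_string(&req).unwrap());
}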
diff --git a/rig-core/src/providers/xai/completion.rs b/rig-core/src/providers/xai/completion.rs
index 23a6e7a9e..001eb0258 100644
--- a/rig-core/src/providers/xai/completion.rs
+++ b/rig-core/src/providers/xai/completion.rs
@@ -6,7 +6,6 @@
 use crate::{
    completion::{self, CompletionError},
    http_client::HttpClientExt,
-    json_utils,
    providers::openai::Message,
 };

@@ -15,7 +14,7 @@
 use crate::completion::CompletionRequest;
 use crate::providers::openai;
 use crate::streaming::StreamingCompletionResponse;
 use bytes::Bytes;
-use serde_json::{Value, json};
+use serde::{Deserialize, Serialize};
 use tracing::{Instrument, info_span};
 use xai_api_types::{CompletionResponse, ToolDefinition};

@@ -29,30 +28,36 @@
 pub const GROK_3_MINI_FAST: &str = "grok-3-mini-fast";
 pub const GROK_2_IMAGE_1212: &str = "grok-2-image-1212";
 pub const GROK_4: &str = "grok-4-0709";

-// =================================================================
-// Rig Implementation Types
-// =================================================================
-
-#[derive(Clone)]
-pub struct CompletionModel {
-    pub(crate) client: Client,
-    pub model: String,
+#[derive(Debug, Serialize, Deserialize)]
+pub(super) struct XAICompletionRequest {
+    model: String,
+    pub messages: Vec<Message>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f64>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    tools: Vec<ToolDefinition>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tool_choice: Option<crate::providers::openrouter::ToolChoice>,
+    #[serde(flatten, skip_serializing_if = "Option::is_none")]
+    pub additional_params: Option<serde_json::Value>,
 }

-impl CompletionModel {
-    pub(crate) fn create_completion_request(
-        &self,
-        completion_request: completion::CompletionRequest,
-    ) -> Result<Value, CompletionError> {
-        // Convert documents into user message
-        let docs: Option<Vec<Message>> = completion_request
-            .normalized_documents()
-            .map(|docs| docs.try_into())
-            .transpose()?;
+impl TryFrom<(&str, CompletionRequest)> for XAICompletionRequest {
+    type Error = CompletionError;

-        // Convert existing chat history
-        let chat_history: Vec<Message> = completion_request
+    fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
+        let mut full_history: Vec<Message> = match &req.preamble {
+            Some(preamble) => vec![Message::system(preamble)],
+            None => vec![],
+        };
+        if let Some(docs) = req.normalized_documents() {
+            let docs: Vec<Message> = docs.try_into()?;
+            full_history.extend(docs);
+        }
+
+        let chat_history: Vec<Message> = req
            .chat_history
+            .clone()
            .into_iter()
            .map(|message| message.try_into())
            .collect::<Result<Vec<Vec<Message>>, _>>()?
@@ -60,58 +65,38 @@ impl CompletionModel {
            .flatten()
            .collect();

-        // Init full history with preamble (or empty if non-existent)
-        let mut full_history: Vec<Message> = match &completion_request.preamble {
-            Some(preamble) => vec![Message::system(preamble)],
-            None => vec![],
-        };
-
-        // Docs appear right after preamble, if they exist
-        if let Some(docs) = docs {
-            full_history.extend(docs)
-        }
-
-        // Chat history and prompt appear in the order they were provided
        full_history.extend(chat_history);

-        let tool_choice = completion_request
+        let tool_choice = req
            .tool_choice
+            .clone()
            .map(crate::providers::openrouter::ToolChoice::try_from)
            .transpose()?;

-        let mut request = if completion_request.tools.is_empty() {
-            json!({
-                "model": self.model,
-                "messages": full_history,
-                "temperature": completion_request.temperature,
-            })
-        } else {
-            json!({
-                "model": self.model,
-                "messages": full_history,
-                "temperature": completion_request.temperature,
-                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
-                "tool_choice": tool_choice,
-            })
-        };
-
-        request = if let Some(params) = completion_request.additional_params {
-            json_utils::merge(request, params)
-        } else {
-            request
-        };
-
-        Ok(request)
+        Ok(Self {
+            model: model.to_string(),
+            messages: full_history,
+            temperature: req.temperature,
+            tools: req
+                .tools
+                .clone()
+                .into_iter()
+                .map(ToolDefinition::from)
+                .collect::<Vec<_>>(),
+            tool_choice,
+            additional_params: req.additional_params,
+        })
    }
+}

-    pub fn new(client: Client, model: impl Into<String>) -> Self {
-        Self {
-            client,
-            model: model.into(),
-        }
-    }
+#[derive(Clone)]
+pub struct CompletionModel {
+    pub(crate) client: Client,
+    pub model: String,
+}

-    pub fn with_model(client: Client, model: &str) -> Self {
+impl CompletionModel {
+    pub fn new(client: Client, model: impl Into<String>) -> Self {
        Self {
            client,
            model: model.into(),
@@ -138,9 +123,9 @@ where
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();
-        let request = self.create_completion_request(completion_request)?;
-        let request_messages_json_str =
-            serde_json::to_string(&request.get("messages").unwrap()).unwrap();
+        let request =
+            XAICompletionRequest::try_from((self.model.as_str(), completion_request))?;
+        let request_messages_json_str = serde_json::to_string(&request.messages)?;

        let span = if tracing::Span::current().is_disabled() {
            info_span!(
diff --git a/rig-core/src/providers/xai/streaming.rs b/rig-core/src/providers/xai/streaming.rs
index 67acd5cdb..e430bcd5f 100644
--- a/rig-core/src/providers/xai/streaming.rs
+++ b/rig-core/src/providers/xai/streaming.rs
@@ -1,11 +1,10 @@
 use crate::completion::{CompletionError, CompletionRequest};
 use crate::http_client::HttpClientExt;
-use crate::json_utils::merge;
+use crate::json_utils;
 use crate::providers::openai;
 use crate::providers::openai::send_compatible_streaming_request;
-use crate::providers::xai::completion::CompletionModel;
+use crate::providers::xai::completion::{CompletionModel, XAICompletionRequest};
 use crate::streaming::StreamingCompletionResponse;
-use serde_json::json;
 use tracing::{Instrument, info_span};

 impl CompletionModel
@@ -18,9 +17,15 @@ where
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let preamble = completion_request.preamble.clone();

-        let mut request = self.create_completion_request(completion_request)?;
+        let mut request =
+            XAICompletionRequest::try_from((self.model.as_str(), completion_request))?;

-        request = merge(request, json!({"stream": true}));
+        let params = json_utils::merge(
+            request.additional_params.unwrap_or(serde_json::json!({})),
+            serde_json::json!({"stream": true }),
+        );
+
+        request.additional_params = Some(params);

        let body = serde_json::to_vec(&request)?;
        let req = self
@@ -42,7 +47,7 @@ where
            gen_ai.response.model = tracing::field::Empty,
            gen_ai.usage.output_tokens = tracing::field::Empty,
            gen_ai.usage.input_tokens = tracing::field::Empty,
-            gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
+            gen_ai.input.messages = serde_json::to_string(&request.messages)?,
            gen_ai.output.messages = tracing::field::Empty,
        )
    } else {
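Since the point of this refactor is that each typed struct serializes to the same body the old `json!`-based builders produced, a cheap regression test is to compare the two `serde_json::Value` trees. A hypothetical sketch with stand-in types (not wired to any provider in this diff):

use serde::Serialize;
use serde_json::json;

// Simplified stand-in for a typed provider request.
#[derive(Serialize)]
struct TypedSketch {
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
}

fn main() {
    // The typed struct and the legacy json! shape should agree field-for-field.
    let typed = serde_json::to_value(TypedSketch {
        model: "grok-4-0709".into(),
        temperature: Some(0.2),
    })
    .unwrap();
    let legacy = json!({ "model": "grok-4-0709", "temperature": 0.2 });
    assert_eq!(typed, legacy);
}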