143 changes: 75 additions & 68 deletions rig-core/src/providers/anthropic/completion.rs
@@ -4,7 +4,6 @@ use crate::{
OneOrMany,
completion::{self, CompletionError, GetTokenUsage},
http_client::HttpClientExt,
json_utils,
message::{self, DocumentMediaType, DocumentSourceKind, MessageError, Reasoning},
one_or_many::string_or_one_or_many,
telemetry::{ProviderResponseExt, SpanCombinator},
@@ -17,7 +16,6 @@ use crate::completion::CompletionRequest;
use crate::providers::anthropic::streaming::StreamingCompletionResponse;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{Instrument, info_span};

// ================================================================
@@ -721,6 +719,68 @@ impl TryFrom<message::ToolChoice> for ToolChoice {
Ok(res)
}
}

#[derive(Debug, Deserialize, Serialize)]
struct AnthropicCompletionRequest {
model: String,
messages: Vec<Message>,
max_tokens: u64,
system: String,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<ToolChoice>,
#[serde(skip_serializing_if = "Vec::is_empty")]
tools: Vec<ToolDefinition>,
#[serde(flatten, skip_serializing_if = "Option::is_none")]
additional_params: Option<serde_json::Value>,
}

impl TryFrom<(&str, CompletionRequest)> for AnthropicCompletionRequest {
type Error = CompletionError;

fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
// Check if max_tokens is set, required for Anthropic
let Some(max_tokens) = req.max_tokens else {
return Err(CompletionError::RequestError(
"`max_tokens` must be set for Anthropic".into(),
));
};

let mut full_history = vec![];
if let Some(docs) = req.normalized_documents() {
full_history.push(docs);
}
full_history.extend(req.chat_history);

let messages = full_history
.into_iter()
.map(Message::try_from)
.collect::<Result<Vec<Message>, _>>()?;

let tools = req
.tools
.into_iter()
.map(|tool| ToolDefinition {
name: tool.name,
description: Some(tool.description),
input_schema: tool.parameters,
})
.collect::<Vec<_>>();

Ok(Self {
model: model.to_string(),
messages,
max_tokens,
system: req.preamble.unwrap_or_default(),
temperature: req.temperature,
tool_choice: req.tool_choice.and_then(|x| ToolChoice::try_from(x).ok()),
tools,
additional_params: req.additional_params,
})
}
}

impl<T> completion::CompletionModel for CompletionModel<T>
where
T: HttpClientExt + Clone + Default + WasmCompatSend + WasmCompatSync + 'static,
@@ -736,15 +796,15 @@ where
#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
&self,
completion_request: completion::CompletionRequest,
mut completion_request: completion::CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
let span = if tracing::Span::current().is_disabled() {
info_span!(
target: "rig::completions",
"chat",
gen_ai.operation.name = "chat",
gen_ai.provider.name = "anthropic",
gen_ai.request.model = self.model,
gen_ai.request.model = &self.model,
gen_ai.system_instructions = &completion_request.preamble,
gen_ai.response.id = tracing::field::Empty,
gen_ai.response.model = tracing::field::Empty,
@@ -756,75 +816,21 @@ where
} else {
tracing::Span::current()
};
// Note: Ideally we'd introduce provider-specific Request models to handle the
// specific requirements of each provider. For now, we just manually check while
// building the request as a raw JSON document.

// Check if max_tokens is set, required for Anthropic
let max_tokens = if let Some(tokens) = completion_request.max_tokens {
tokens
} else if let Some(tokens) = self.default_max_tokens {
tokens
} else {
return Err(CompletionError::RequestError(
"`max_tokens` must be set for Anthropic".into(),
));
};

let mut full_history = vec![];
if let Some(docs) = completion_request.normalized_documents() {
full_history.push(docs);
}
full_history.extend(completion_request.chat_history);
span.record_model_input(&full_history);

let full_history = full_history
.into_iter()
.map(Message::try_from)
.collect::<Result<Vec<Message>, _>>()?;

let mut request = json!({
"model": self.model,
"messages": full_history,
"max_tokens": max_tokens,
"system": completion_request.preamble.unwrap_or("".to_string()),
});

if let Some(temperature) = completion_request.temperature {
json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
}

let tool_choice = if let Some(tool_choice) = completion_request.tool_choice {
Some(ToolChoice::try_from(tool_choice)?)
} else {
None
};

if !completion_request.tools.is_empty() {
let mut tools_json = json!({
"tools": completion_request
.tools
.into_iter()
.map(|tool| ToolDefinition {
name: tool.name,
description: Some(tool.description),
input_schema: tool.parameters,
})
.collect::<Vec<_>>(),
});

// Only include tool_choice if it's explicitly set (not None)
// When omitted, Anthropic defaults to "auto"
if let Some(tc) = tool_choice {
tools_json["tool_choice"] = serde_json::to_value(tc)?;
if completion_request.max_tokens.is_none() {
if let Some(tokens) = self.default_max_tokens {
completion_request.max_tokens = Some(tokens);
} else {
return Err(CompletionError::RequestError(
"`max_tokens` must be set for Anthropic".into(),
));
}

json_utils::merge_inplace(&mut request, tools_json);
}

if let Some(ref params) = completion_request.additional_params {
json_utils::merge_inplace(&mut request, params.clone())
}
let request =
AnthropicCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
span.record_model_input(&request.messages);

async move {
let request: Vec<u8> = serde_json::to_vec(&request)?;
@@ -909,6 +915,7 @@ enum ApiResponse<T> {
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use serde_path_to_error::deserialize;

#[test]
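For context on the pattern above: the new typed request struct relies on `#[serde(flatten)]` over an `Option<serde_json::Value>` to splice caller-supplied extras into the top-level JSON object, which is what lets the rewrite drop the manual `json_utils::merge_inplace` calls. A minimal standalone sketch of that serde behavior, with a trimmed field set and hypothetical names rather than rig's actual types:

use serde::Serialize;
use serde_json::json;

// Illustrative stand-in for the request structs in the diff above.
#[derive(Serialize)]
struct Request {
    model: String,
    max_tokens: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // `flatten` splices the inner object's keys into the top level;
    // a `None` contributes nothing to the serialized output.
    #[serde(flatten)]
    additional_params: Option<serde_json::Value>,
}

fn main() -> Result<(), serde_json::Error> {
    let req = Request {
        model: "claude-sonnet".into(),
        max_tokens: 1024,
        temperature: None,
        additional_params: Some(json!({ "top_k": 40 })),
    };
    // `temperature` is omitted and `top_k` lands at the top level.
    assert_eq!(
        serde_json::to_value(&req)?,
        json!({ "model": "claude-sonnet", "max_tokens": 1024, "top_k": 40 })
    );
    Ok(())
}

The practical upside over building raw json! documents is that required fields are enforced by the type at construction time instead of being checked ad hoc while assembling the request.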
135 changes: 75 additions & 60 deletions rig-core/src/providers/azure.rs
@@ -19,7 +19,6 @@ use crate::client::{
};
use crate::completion::GetTokenUsage;
use crate::http_client::{self, HttpClientExt, bearer_auth_header};
use crate::json_utils::merge;
use crate::streaming::StreamingCompletionResponse;
use crate::transcription::TranscriptionError;
use crate::{
@@ -32,7 +31,7 @@ use crate::{
};
use bytes::Bytes;
use reqwest::multipart::Part;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use serde_json::json;
// ================================================================
// Main Azure OpenAI Client
@@ -544,42 +543,44 @@ pub const GPT_35_TURBO_INSTRUCT: &str = "gpt-3.5-turbo-instruct";
/// `gpt-3.5-turbo-16k` completion model
pub const GPT_35_TURBO_16K: &str = "gpt-3.5-turbo-16k";

#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
client: Client<T>,
/// Name of the model (e.g.: gpt-4o-mini)
pub model: String,
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct AzureOpenAICompletionRequest {
model: String,
pub messages: Vec<openai::Message>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f64>,
#[serde(skip_serializing_if = "Vec::is_empty")]
tools: Vec<openai::ToolDefinition>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<crate::providers::openrouter::ToolChoice>,
#[serde(flatten, skip_serializing_if = "Option::is_none")]
pub additional_params: Option<serde_json::Value>,
}

impl<T> CompletionModel<T> {
pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
Self {
client,
model: model.into(),
}
}
impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest {
type Error = CompletionError;

pub fn with_model(client: Client<T>, model: &str) -> Self {
Self {
client,
model: model.into(),
fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
//FIXME: Must fix!
if req.tool_choice.is_some() {
tracing::warn!(
"Tool choice is currently not supported in Azure OpenAI. This should be fixed by Rig 0.25."
);
}
}

fn create_completion_request(
&self,
completion_request: CompletionRequest,
) -> Result<serde_json::Value, CompletionError> {
let mut full_history: Vec<openai::Message> = match &completion_request.preamble {
let mut full_history: Vec<openai::Message> = match &req.preamble {
Some(preamble) => vec![openai::Message::system(preamble)],
None => vec![],
};
if let Some(docs) = completion_request.normalized_documents() {

if let Some(docs) = req.normalized_documents() {
let docs: Vec<openai::Message> = docs.try_into()?;
full_history.extend(docs);
}
let chat_history: Vec<openai::Message> = completion_request

let chat_history: Vec<openai::Message> = req
.chat_history
.clone()
.into_iter()
.map(|message| message.try_into())
.collect::<Result<Vec<Vec<openai::Message>>, _>>()?
@@ -589,29 +590,41 @@ impl<T> CompletionModel<T> {

full_history.extend(chat_history);

let request = if completion_request.tools.is_empty() {
json!({
"model": self.model,
"messages": full_history,
"temperature": completion_request.temperature,
})
} else {
json!({
"model": self.model,
"messages": full_history,
"temperature": completion_request.temperature,
"tools": completion_request.tools.into_iter().map(openai::ToolDefinition::from).collect::<Vec<_>>(),
"tool_choice": "auto",
})
};
let tool_choice = req
.tool_choice
.clone()
.map(crate::providers::openrouter::ToolChoice::try_from)
.transpose()?;

Ok(Self {
model: model.to_string(),
messages: full_history,
temperature: req.temperature,
tools: req
.tools
.clone()
.into_iter()
.map(openai::ToolDefinition::from)
.collect::<Vec<_>>(),
tool_choice,
additional_params: req.additional_params,
})
}
}

let request = if let Some(params) = completion_request.additional_params {
json_utils::merge(request, params)
} else {
request
};
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
client: Client<T>,
/// Name of the model (e.g.: gpt-4o-mini)
pub model: String,
}

Ok(request)
impl<T> CompletionModel<T> {
pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
Self {
client,
model: model.into(),
}
}
}

@@ -650,12 +663,11 @@ where
} else {
tracing::Span::current()
};
let request = self.create_completion_request(completion_request)?;
span.record_model_input(
&request
.get("messages")
.expect("Converting JSON should not fail"),
);

let request =
AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

span.record_model_input(&request.messages);
let body = serde_json::to_vec(&request)?;

let req = self
Expand Down Expand Up @@ -707,16 +719,19 @@ where
#[cfg_attr(feature = "worker", worker::send)]
async fn stream(
&self,
request: CompletionRequest,
completion_request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let preamble = request.preamble.clone();
let mut request = self.create_completion_request(request)?;
let preamble = completion_request.preamble.clone();
let mut request =
AzureOpenAICompletionRequest::try_from((self.model.as_ref(), completion_request))?;

request = merge(
request,
json!({"stream": true, "stream_options": {"include_usage": true}}),
let params = json_utils::merge(
request.additional_params.unwrap_or(serde_json::json!({})),
serde_json::json!({"stream": true, "stream_options": {"include_usage": true} }),
);

request.additional_params = Some(params);

let body = serde_json::to_vec(&request)?;

let req = self
@@ -738,7 +753,7 @@ where
gen_ai.response.model = tracing::field::Empty,
gen_ai.usage.output_tokens = tracing::field::Empty,
gen_ai.usage.input_tokens = tracing::field::Empty,
gen_ai.input.messages = serde_json::to_string(&request.get("messages").unwrap()).unwrap(),
gen_ai.input.messages = serde_json::to_string(&request.messages)?,
gen_ai.output.messages = tracing::field::Empty,
)
} else {
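For context on the streaming change above: `json_utils::merge` is internal to rig, but its role here is a shallow JSON merge that layers the streaming flags over any caller-supplied `additional_params`, which the flattened field then serializes at the top level. A rough standalone sketch under that assumption; this `merge` is a hypothetical stand-in, not rig's implementation:

use serde_json::{Value, json};

// Hypothetical stand-in for rig's internal `json_utils::merge`: a shallow
// merge where keys in `patch` overwrite keys in `base`.
fn merge(mut base: Value, patch: Value) -> Value {
    if let (Value::Object(base_map), Value::Object(patch_map)) = (&mut base, patch) {
        for (k, v) in patch_map {
            base_map.insert(k, v);
        }
    }
    base
}

fn main() {
    // Caller-supplied extras on the request.
    let additional_params = json!({ "top_p": 0.9 });
    // The streaming path layers its flags on top before serializing.
    let merged = merge(
        additional_params,
        json!({ "stream": true, "stream_options": { "include_usage": true } }),
    );
    assert_eq!(
        merged,
        json!({ "top_p": 0.9, "stream": true, "stream_options": { "include_usage": true } })
    );
}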