Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions rig-bedrock/src/types/assistant_content.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ impl TryFrom<RigAssistantContent> for aws_bedrock::ContentBlock {
aws_bedrock::ReasoningContentBlock::ReasoningText(reasoning_text_block),
))
}
AssistantContent::Image(_) => Err(CompletionError::ProviderError(
"AWS Bedrock does not support image content in assistant messages".to_owned(),
)),
}
}
}
Expand Down
15 changes: 15 additions & 0 deletions rig-core/src/completion/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ pub enum AssistantContent {
Text(Text),
ToolCall(ToolCall),
Reasoning(Reasoning),
Image(Image),
}

#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
Expand Down Expand Up @@ -591,6 +592,20 @@ impl AssistantContent {
AssistantContent::Text(text.into().into())
}

/// Helper constructor to make creating assistant image content easier.
pub fn image_base64(
data: impl Into<String>,
media_type: Option<ImageMediaType>,
detail: Option<ImageDetail>,
) -> Self {
AssistantContent::Image(Image {
data: DocumentSourceKind::Base64(data.into()),
media_type,
detail,
additional_params: None,
})
}

/// Helper constructor to make creating assistant tool call content easier.
pub fn tool_call(
id: impl Into<String>,
Expand Down
22 changes: 13 additions & 9 deletions rig-core/src/providers/anthropic/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,25 +386,29 @@ impl TryFrom<DocumentMediaType> for DocumentFormat {
}
}

impl From<message::AssistantContent> for Content {
fn from(text: message::AssistantContent) -> Self {
impl TryFrom<message::AssistantContent> for Content {
type Error = MessageError;
fn try_from(text: message::AssistantContent) -> Result<Self, Self::Error> {
match text {
message::AssistantContent::Text(message::Text { text }) => Content::Text { text },
message::AssistantContent::Text(message::Text { text }) => Ok(Content::Text { text }),
message::AssistantContent::Image(_) => Err(MessageError::ConversionError(
"Anthropic currently doesn't support images.".to_string(),
)),
message::AssistantContent::ToolCall(message::ToolCall { id, function, .. }) => {
Content::ToolUse {
Ok(Content::ToolUse {
id,
name: function.name,
input: function.arguments,
}
})
}
message::AssistantContent::Reasoning(Reasoning {
reasoning,
signature,
..
}) => Content::Thinking {
}) => Ok(Content::Thinking {
thinking: reasoning.first().cloned().unwrap_or(String::new()),
signature,
},
}),
}
}
}
Expand Down Expand Up @@ -452,7 +456,7 @@ impl TryFrom<message::Message> for Message {
data, media_type, ..
}) => {
let media_type = media_type.ok_or(MessageError::ConversionError(
"Image media type is required for Claude API".into(),
"Image media type is required for Claude API".to_string(),
))?;

let source = match data {
Expand Down Expand Up @@ -515,7 +519,7 @@ impl TryFrom<message::Message> for Message {
},

message::Message::Assistant { content, .. } => Message {
content: content.map(|content| content.into()),
content: content.try_map(|content| content.try_into())?,
role: Role::Assistant,
},
})
Expand Down
58 changes: 33 additions & 25 deletions rig-core/src/providers/cohere/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,32 +393,40 @@ impl TryFrom<message::Message> for Vec<Message> {
message::Message::Assistant { content, .. } => {
let mut text_content = vec![];
let mut tool_calls = vec![];
content.into_iter().for_each(|content| match content {
message::AssistantContent::Text(message::Text { text }) => {
text_content.push(AssistantContent::Text { text });
}
message::AssistantContent::ToolCall(message::ToolCall {
id,
function:
message::ToolFunction {
name, arguments, ..
},
..
}) => {
tool_calls.push(ToolCall {
id: Some(id),
r#type: Some(ToolType::Function),
function: Some(ToolCallFunction {
name,
arguments: serde_json::to_value(arguments).unwrap_or_default(),
}),
});
}
message::AssistantContent::Reasoning(Reasoning { reasoning, .. }) => {
let thinking = reasoning.join("\n");
text_content.push(AssistantContent::Thinking { thinking });

for content in content.into_iter() {
match content {
message::AssistantContent::Text(message::Text { text }) => {
text_content.push(AssistantContent::Text { text });
}
message::AssistantContent::ToolCall(message::ToolCall {
id,
function:
message::ToolFunction {
name, arguments, ..
},
..
}) => {
tool_calls.push(ToolCall {
id: Some(id),
r#type: Some(ToolType::Function),
function: Some(ToolCallFunction {
name,
arguments: serde_json::to_value(arguments).unwrap_or_default(),
}),
});
}
message::AssistantContent::Reasoning(Reasoning { reasoning, .. }) => {
let thinking = reasoning.join("\n");
text_content.push(AssistantContent::Thinking { thinking });
}
message::AssistantContent::Image(_) => {
return Err(message::MessageError::ConversionError(
"Cohere currently doesn't support images.".to_owned(),
));
}
}
});
}

vec![Message::Assistant {
content: text_content,
Expand Down
5 changes: 5 additions & 0 deletions rig-core/src/providers/galadriel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,11 @@ impl TryFrom<message::Message> for Message {
"Galadriel currently doesn't support reasoning.".into(),
));
}
message::AssistantContent::Image(_) => {
return Err(MessageError::ConversionError(
"Galadriel currently doesn't support images.".into(),
));
}
}
}

Expand Down
77 changes: 69 additions & 8 deletions rig-core/src/providers/gemini/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub const GEMINI_1_0_PRO: &str = "gemini-1.0-pro";

use self::gemini_api_types::Schema;
use crate::http_client::HttpClientExt;
use crate::message::Reasoning;
use crate::message::{self, MimeType, Reasoning};
use crate::models;
use crate::providers::gemini::completion::gemini_api_types::{
AdditionalParameters, FunctionCallingMode, ToolConfig,
Expand Down Expand Up @@ -388,6 +388,24 @@ impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<Generat
completion::AssistantContent::text(text)
}
}
PartKind::InlineData(inline_data) => {
let mime_type = message::MediaType::from_mime_type(&inline_data.mime_type);

match mime_type {
Some(message::MediaType::Image(media_type)) => {
message::AssistantContent::image_base64(
&inline_data.data,
Some(media_type),
Some(message::ImageDetail::default()),
)
}
_ => {
return Err(CompletionError::ResponseError(format!(
"Unsupported media type {mime_type:?}"
)));
}
}
}
PartKind::FunctionCall(function_call) => {
completion::AssistantContent::tool_call(
&function_call.name,
Expand Down Expand Up @@ -595,7 +613,10 @@ pub mod gemini_api_types {
},
message::Message::Assistant { content, .. } => Content {
role: Some(Role::Model),
parts: content.into_iter().map(|content| content.into()).collect(),
parts: content
.into_iter()
.map(|content| content.try_into())
.collect::<Result<Vec<_>, _>>()?,
},
})
}
Expand Down Expand Up @@ -1014,20 +1035,47 @@ pub mod gemini_api_types {
}
}

impl From<message::AssistantContent> for Part {
fn from(content: message::AssistantContent) -> Self {
impl TryFrom<message::AssistantContent> for Part {
type Error = message::MessageError;

fn try_from(content: message::AssistantContent) -> Result<Self, Self::Error> {
match content {
message::AssistantContent::Text(message::Text { text }) => text.into(),
message::AssistantContent::ToolCall(tool_call) => tool_call.into(),
message::AssistantContent::Text(message::Text { text }) => Ok(text.into()),
message::AssistantContent::Image(message::Image {
data, media_type, ..
}) => match media_type {
Some(media_type) => match media_type {
message::ImageMediaType::JPEG
| message::ImageMediaType::PNG
| message::ImageMediaType::WEBP
| message::ImageMediaType::HEIC
| message::ImageMediaType::HEIF => {
let part = PartKind::try_from((media_type, data))?;
Ok(Part {
thought: Some(false),
thought_signature: None,
part,
additional_params: None,
})
}
_ => Err(message::MessageError::ConversionError(format!(
"Unsupported image media type {media_type:?}"
))),
},
None => Err(message::MessageError::ConversionError(
"Media type for image is required for Gemini".to_string(),
)),
},
message::AssistantContent::ToolCall(tool_call) => Ok(tool_call.into()),
message::AssistantContent::Reasoning(message::Reasoning { reasoning, .. }) => {
Part {
Ok(Part {
thought: Some(true),
thought_signature: None,
part: PartKind::Text(
reasoning.first().cloned().unwrap_or_else(|| "".to_string()),
),
additional_params: None,
}
})
}
}
}
Expand Down Expand Up @@ -1362,6 +1410,8 @@ pub mod gemini_api_types {
/// Configuration for thinking/reasoning.
#[serde(skip_serializing_if = "Option::is_none")]
pub thinking_config: Option<ThinkingConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
pub image_config: Option<ImageConfig>,
}

impl Default for GenerationConfig {
Expand All @@ -1380,6 +1430,7 @@ pub mod gemini_api_types {
response_logprobs: None,
logprobs: None,
thinking_config: None,
image_config: None,
}
}
}
Expand All @@ -1390,6 +1441,16 @@ pub mod gemini_api_types {
pub thinking_budget: u32,
pub include_thoughts: Option<bool>,
}

#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ImageConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub aspect_ratio: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub image_size: Option<String>,
}

/// The Schema object allows the definition of input and output data types. These types can be objects, but also
/// primitives and arrays. Represents a select subset of an OpenAPI 3.0 schema object.
/// From [Gemini API Reference](https://ai.google.dev/api/caching#Schema)
Expand Down
5 changes: 5 additions & 0 deletions rig-core/src/providers/groq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,11 @@ impl TryFrom<message::Message> for Message {
groq_reasoning =
Some(reasoning.first().cloned().unwrap_or(String::new()));
}
message::AssistantContent::Image(_) => {
return Err(MessageError::ConversionError(
"Ollama currently doesn't support images.".into(),
));
}
}
}

Expand Down
5 changes: 5 additions & 0 deletions rig-core/src/providers/huggingface/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,11 @@ impl TryFrom<message::Message> for Vec<Message> {
message::AssistantContent::Reasoning(_) => {
unimplemented!("Reasoning is not supported on HuggingFace via Rig");
}
message::AssistantContent::Image(_) => {
unimplemented!(
"Image content is not supported on HuggingFace via Rig"
);
}
}
(texts, tools)
},
Expand Down
6 changes: 6 additions & 0 deletions rig-core/src/providers/mistral/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ impl TryFrom<message::Message> for Vec<Message> {
message::AssistantContent::Reasoning(_) => {
unimplemented!("Reasoning content is not currently supported on Mistral via Rig");
}
message::AssistantContent::Image(_) => {
unimplemented!("Image content is not currently supported on Mistral via Rig");
}
}
(texts, tools)
},
Expand Down Expand Up @@ -589,6 +592,9 @@ where
message::AssistantContent::Reasoning(_) => {
unimplemented!("Reasoning is not supported on Mistral via Rig")
}
message::AssistantContent::Image(_) => {
unimplemented!("Image content is not supported on Mistral via Rig")
}
}
}

Expand Down
39 changes: 22 additions & 17 deletions rig-core/src/providers/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,24 +844,29 @@ impl TryFrom<crate::message::Message> for Vec<Message> {
}
InternalMessage::Assistant { content, .. } => {
let mut thinking: Option<String> = None;
let (text_content, tool_calls) = content.into_iter().fold(
(Vec::new(), Vec::new()),
|(mut texts, mut tools), content| {
match content {
crate::message::AssistantContent::Text(text) => texts.push(text.text),
crate::message::AssistantContent::ToolCall(tool_call) => {
tools.push(tool_call)
}
crate::message::AssistantContent::Reasoning(
crate::message::Reasoning { reasoning, .. },
) => {
thinking =
Some(reasoning.first().cloned().unwrap_or(String::new()));
}
let mut text_content = Vec::new();
let mut tool_calls = Vec::new();

for content in content.into_iter() {
match content {
crate::message::AssistantContent::Text(text) => {
text_content.push(text.text)
}
(texts, tools)
},
);
crate::message::AssistantContent::ToolCall(tool_call) => {
tool_calls.push(tool_call)
}
crate::message::AssistantContent::Reasoning(
crate::message::Reasoning { reasoning, .. },
) => {
thinking = Some(reasoning.first().cloned().unwrap_or(String::new()));
}
crate::message::AssistantContent::Image(_) => {
return Err(crate::message::MessageError::ConversionError(
"Ollama currently doesn't support images.".into(),
));
}
}
}

// `OneOrMany` ensures at least one `AssistantContent::Text` or `ToolCall` exists,
// so either `content` or `tool_calls` will have some content.
Expand Down
Loading