Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 9 additions & 3 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,27 @@
let
overlays = [ (import rust-overlay) ];
pkgs = import nixpkgs {
inherit system overlays;
};
inherit system overlays;
};
rustToolchain = pkgs.rust-bin.stable."1.90.0".default.override {
targets = [ "wasm32-unknown-unknown" ];
};
in
{
devShells.default = with pkgs; mkShell {
buildInputs = [
pkg-config
cmake
just

openssl
sqlite
postgresql
protobuf

rust-bin.stable."1.90.0".default
rustToolchain
wasm-bindgen-cli
wasm-pack
];

OPENSSL_DEV = openssl.dev;
Expand Down
1 change: 1 addition & 0 deletions rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ worker = { workspace = true, optional = true }
rmcp = { version = "0.6", optional = true, features = ["client"] }
reqwest-eventsource = { workspace = true }
tokio = { workspace = true, features = ["sync"] }
http = "1.3.1"
tracing-futures = { version = "0.2.5", features = ["futures-03"] }

[dev-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion rig-core/examples/agent_with_galadriel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ async fn main() -> Result<(), anyhow::Error> {
if let Some(fine_tune_api_key) = fine_tune_api_key.as_deref() {
builder = builder.fine_tune_api_key(fine_tune_api_key);
}
let client = builder.build().expect("Failed to build client");
let client = builder.build();

// Create agent with a single context prompt
let comedian_agent = client
Expand Down
2 changes: 1 addition & 1 deletion rig-core/examples/pdf_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ async fn main() -> Result<()> {
// Initialize Ollama client
let client = openai::Client::builder("ollama")
.base_url("http://localhost:11434/v1")
.build()?;
.build();

// Load PDFs using Rig's built-in PDF loader
let documents_dir = std::env::current_dir()?.join("rig-core/examples/documents");
Expand Down
2 changes: 1 addition & 1 deletion rig-core/examples/vector_search_ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async fn main() -> Result<(), anyhow::Error> {
// Create ollama client
let client = providers::ollama::Client::builder()
.base_url("http://localhost:11434")
.build()?;
.build();
let embedding_model = client.embedding_model("nomic-embed-text");

let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
Expand Down
18 changes: 11 additions & 7 deletions rig-core/src/agent/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{
streaming::{StreamingChat, StreamingCompletion, StreamingPrompt},
tool::ToolSet,
vector_store::{VectorStoreError, request::VectorSearchRequest},
wasm_compat::WasmCompatSend,
};
use futures::{StreamExt, TryStreamExt, stream};
use std::{collections::HashMap, sync::Arc};
Expand Down Expand Up @@ -85,7 +86,7 @@ where
{
async fn completion(
&self,
prompt: impl Into<Message> + Send,
prompt: impl Into<Message> + WasmCompatSend,
chat_history: Vec<Message>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
let prompt = prompt.into();
Expand Down Expand Up @@ -228,7 +229,7 @@ where
{
fn prompt(
&self,
prompt: impl Into<Message> + Send,
prompt: impl Into<Message> + WasmCompatSend,
) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
PromptRequest::new(self, prompt)
}
Expand All @@ -242,7 +243,7 @@ where
#[tracing::instrument(skip(self, prompt), fields(agent_name = self.name()))]
fn prompt(
&self,
prompt: impl Into<Message> + Send,
prompt: impl Into<Message> + WasmCompatSend,
) -> PromptRequest<'_, prompt_request::Standard, M, ()> {
PromptRequest::new(*self, prompt)
}
Expand All @@ -256,7 +257,7 @@ where
#[tracing::instrument(skip(self, prompt, chat_history), fields(agent_name = self.name()))]
async fn chat(
&self,
prompt: impl Into<Message> + Send,
prompt: impl Into<Message> + WasmCompatSend,
mut chat_history: Vec<Message>,
) -> Result<String, PromptError> {
PromptRequest::new(self, prompt)
Expand All @@ -271,7 +272,7 @@ where
{
async fn stream_completion(
&self,
prompt: impl Into<Message> + Send,
prompt: impl Into<Message> + WasmCompatSend,
chat_history: Vec<Message>,
) -> Result<CompletionRequestBuilder<M>, CompletionError> {
// Reuse the existing completion implementation to build the request
Expand All @@ -285,7 +286,10 @@ where
M: CompletionModel + 'static,
M::StreamingResponse: GetTokenUsage,
{
fn stream_prompt(&self, prompt: impl Into<Message> + Send) -> StreamingPromptRequest<M, ()> {
fn stream_prompt(
&self,
prompt: impl Into<Message> + WasmCompatSend,
) -> StreamingPromptRequest<M, ()> {
let arc = Arc::new(self.clone());
StreamingPromptRequest::new(arc, prompt)
}
Expand All @@ -298,7 +302,7 @@ where
{
fn stream_chat(
&self,
prompt: impl Into<Message> + Send,
prompt: impl Into<Message> + WasmCompatSend,
chat_history: Vec<Message>,
) -> StreamingPromptRequest<M, ()> {
let arc = Arc::new(self.clone());
Expand Down
25 changes: 15 additions & 10 deletions rig-core/src/agent/prompt_request/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ use std::{
};
use tracing::{Instrument, span::Id};

use futures::{FutureExt, StreamExt, future::BoxFuture, stream};
use futures::{StreamExt, stream};
use tracing::info_span;

use crate::{
OneOrMany,
completion::{Completion, CompletionError, CompletionModel, Message, PromptError, Usage},
message::{AssistantContent, UserContent},
tool::ToolSetError,
wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
};

use super::Agent;
Expand Down Expand Up @@ -136,7 +137,7 @@ where

// dead code allowed because of functions being left empty to allow for users to not have to implement every single function
/// Trait for per-request hooks to observe tool call events.
pub trait PromptHook<M>: Clone + Send + Sync
pub trait PromptHook<M>: Clone + WasmCompatSend + WasmCompatSync
where
M: CompletionModel,
{
Expand All @@ -146,7 +147,7 @@ where
&self,
prompt: &Message,
history: &[Message],
) -> impl Future<Output = ()> + Send {
) -> impl Future<Output = ()> + WasmCompatSend {
async {}
}

Expand All @@ -156,13 +157,17 @@ where
&self,
prompt: &Message,
response: &crate::completion::CompletionResponse<M::Response>,
) -> impl Future<Output = ()> + Send {
) -> impl Future<Output = ()> + WasmCompatSend {
async {}
}

#[allow(unused_variables)]
/// Called before a tool is invoked.
fn on_tool_call(&self, tool_name: &str, args: &str) -> impl Future<Output = ()> + Send {
fn on_tool_call(
&self,
tool_name: &str,
args: &str,
) -> impl Future<Output = ()> + WasmCompatSend {
async {}
}

Expand All @@ -173,7 +178,7 @@ where
tool_name: &str,
args: &str,
result: &str,
) -> impl Future<Output = ()> + Send {
) -> impl Future<Output = ()> + WasmCompatSend {
async {}
}
}
Expand All @@ -189,10 +194,10 @@ where
P: PromptHook<M> + 'static,
{
type Output = Result<String, PromptError>;
type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent
type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

fn into_future(self) -> Self::IntoFuture {
self.send().boxed()
Box::pin(self.send())
}
}

Expand All @@ -202,10 +207,10 @@ where
P: PromptHook<M> + 'static,
{
type Output = Result<PromptResponse, PromptError>;
type IntoFuture = BoxFuture<'a, Self::Output>; // This future should not outlive the agent
type IntoFuture = WasmBoxedFuture<'a, Self::Output>; // This future should not outlive the agent

fn into_future(self) -> Self::IntoFuture {
self.send().boxed()
Box::pin(self.send())
}
}

Expand Down
7 changes: 4 additions & 3 deletions rig-core/src/agent/prompt_request/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
completion::GetTokenUsage,
message::{AssistantContent, Reasoning, ToolResultContent, UserContent},
streaming::{StreamedAssistantContent, StreamingCompletion},
wasm_compat::{WasmBoxedFuture, WasmCompatSend},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -110,7 +111,7 @@ where
impl<M, P> StreamingPromptRequest<M, P>
where
M: CompletionModel + 'static,
<M as CompletionModel>::StreamingResponse: Send + GetTokenUsage,
<M as CompletionModel>::StreamingResponse: WasmCompatSend + GetTokenUsage,
P: StreamingPromptHook<M>,
{
/// Create a new PromptRequest with the given prompt and model
Expand Down Expand Up @@ -395,11 +396,11 @@ where
impl<M, P> IntoFuture for StreamingPromptRequest<M, P>
where
M: CompletionModel + 'static,
<M as CompletionModel>::StreamingResponse: Send,
<M as CompletionModel>::StreamingResponse: WasmCompatSend,
P: StreamingPromptHook<M> + 'static,
{
type Output = StreamingResult<M::StreamingResponse>; // what `.await` returns
type IntoFuture = Pin<Box<dyn futures::Future<Output = Self::Output> + Send>>;
type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

fn into_future(self) -> Self::IntoFuture {
// Wrap send() in a future, because send() returns a stream immediately
Expand Down
4 changes: 2 additions & 2 deletions rig-core/src/audio_generation.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Everything related to audio generation (ie, Text To Speech).
//! Rig abstracts over a number of different providers using the [AudioGenerationModel] trait.
use crate::client::audio_generation::AudioGenerationModelHandle;
use crate::{client::audio_generation::AudioGenerationModelHandle, http_client};
use futures::future::BoxFuture;
use serde_json::Value;
use std::sync::Arc;
Expand All @@ -10,7 +10,7 @@ use thiserror::Error;
pub enum AudioGenerationError {
/// Http error (e.g.: connection error, timeout, etc.)
#[error("HttpError: {0}")]
HttpError(#[from] reqwest::Error),
HttpError(#[from] http_client::Error),

/// Json error (e.g.: serialization, deserialization)
#[error("JsonError: {0}")]
Expand Down
3 changes: 2 additions & 1 deletion rig-core/src/cli_chatbot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
completion::{Chat, CompletionError, CompletionModel, PromptError, Usage},
message::Message,
streaming::{StreamedAssistantContent, StreamingPrompt},
wasm_compat::WasmCompatSend,
};
use futures::StreamExt;
use std::io::{self, Write};
Expand Down Expand Up @@ -60,7 +61,7 @@ where

impl<M> CliChat for AgentImpl<M>
where
M: CompletionModel + 'static,
M: CompletionModel + WasmCompatSend + 'static,
{
async fn request(
&mut self,
Expand Down
Loading