Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jobs:
toolchain: ${{ vars.RUST_VERSION }}

- name: Run cargo check wasm target
run: cargo check --package rig-core --features worker --target wasm32-unknown-unknown
run: cargo check --package rig-core --features wasm --target wasm32-unknown-unknown

clippy:
name: stable / clippy
Expand Down
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ tokio-test = "0.4.4"
tracing = "0.1.41"
tracing-subscriber = "0.3.19"
uuid = "1.17.0"
worker = "0.6"
zerocopy = "0.8.26"

[workspace.metadata.cargo-autoinherit]
Expand Down
3 changes: 1 addition & 2 deletions rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ serde_json = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
url = { workspace = true }
worker = { workspace = true, optional = true }
rmcp = { version = "0.8", optional = true, features = ["client"] }
tokio = { workspace = true, features = ["rt", "sync"] }
http = "1.3.1"
Expand Down Expand Up @@ -89,7 +88,7 @@ discord-bot = ["dep:serenity"]
pdf = ["dep:lopdf"]
epub = ["dep:epub", "dep:quick-xml"]
rayon = ["dep:rayon"]
worker = ["dep:worker", "dep:wasm-bindgen-futures", "futures-timer/wasm-bindgen"]
wasm = ["dep:wasm-bindgen-futures", "futures-timer/wasm-bindgen"]
rmcp = ["dep:rmcp"]
socks = ["reqwest/socks"]
reqwest-tls = ["reqwest/default"]
Expand Down
12 changes: 6 additions & 6 deletions rig-core/src/agent/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ where
/// Maximum number of tokens for the completion
max_tokens: Option<u64>,
/// List of vector store, with the sample number
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Temperature of the model
temperature: Option<f64>,
/// Tool server handle
Expand Down Expand Up @@ -248,10 +248,10 @@ where
pub fn dynamic_tools(
self,
sample: usize,
dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
dynamic_tools: impl VectorStoreIndexDyn + 'static,
toolset: ToolSet,
) -> AgentBuilderSimple<M> {
let thing: Box<dyn VectorStoreIndexDyn + Send + Sync + 'static> = Box::new(dynamic_tools);
let thing: Box<dyn VectorStoreIndexDyn + 'static> = Box::new(dynamic_tools);
let dynamic_tools = vec![(sample, thing)];

AgentBuilderSimple {
Expand Down Expand Up @@ -355,9 +355,9 @@ where
/// Maximum number of tokens for the completion
max_tokens: Option<u64>,
/// List of vector store, with the sample number
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Dynamic tools
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn + Send + Sync>)>,
dynamic_tools: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Temperature of the model
temperature: Option<f64>,
/// Actual tool implementations
Expand Down Expand Up @@ -480,7 +480,7 @@ where
pub fn dynamic_tools(
mut self,
sample: usize,
dynamic_tools: impl VectorStoreIndexDyn + Send + Sync + 'static,
dynamic_tools: impl VectorStoreIndexDyn + 'static,
toolset: ToolSet,
) -> Self {
self.dynamic_tools.push((sample, Box::new(dynamic_tools)));
Expand Down
10 changes: 2 additions & 8 deletions rig-core/src/agent/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,8 @@ use tokio::sync::RwLock;

const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";

pub type DynamicContextStore = Arc<
RwLock<
Vec<(
usize,
Box<dyn crate::vector_store::VectorStoreIndexDyn + Send + Sync>,
)>,
>,
>;
pub type DynamicContextStore =
Arc<RwLock<Vec<(usize, Box<dyn crate::vector_store::VectorStoreIndexDyn>)>>>;

/// Struct representing an LLM agent. An agent is an LLM model combined with a preamble
/// (i.e.: system prompt) and a static set of context documents and tools.
Expand Down
5 changes: 2 additions & 3 deletions rig-core/src/agent/prompt_request/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ use crate::{
tool::ToolSetError,
};

#[cfg(not(target_arch = "wasm32"))]
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type StreamingResult<R> =
Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>> + Send>>;

#[cfg(target_arch = "wasm32")]
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type StreamingResult<R> =
Pin<Box<dyn Stream<Item = Result<MultiTurnStreamItem<R>, StreamingError>>>>;

Expand Down Expand Up @@ -153,7 +153,6 @@ where
}
}

#[cfg_attr(feature = "worker", worker::send)]
async fn send(self) -> StreamingResult<M::StreamingResponse> {
let agent_span = if tracing::Span::current().is_disabled() {
info_span!(
Expand Down
12 changes: 8 additions & 4 deletions rig-core/src/audio_generation.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
//! Everything related to audio generation (ie, Text To Speech).
//! Rig abstracts over a number of different providers using the [AudioGenerationModel] trait.
use crate::{client::audio_generation::AudioGenerationModelHandle, http_client};
use crate::{
client::audio_generation::AudioGenerationModelHandle,
http_client,
wasm_compat::{WasmCompatSend, WasmCompatSync},
};
use futures::future::BoxFuture;
use serde_json::Value;
use std::sync::Arc;
Expand Down Expand Up @@ -46,15 +50,15 @@ where
voice: &str,
) -> impl std::future::Future<
Output = Result<AudioGenerationRequestBuilder<M>, AudioGenerationError>,
> + Send;
> + WasmCompatSend;
}

pub struct AudioGenerationResponse<T> {
pub audio: Vec<u8>,
pub response: T,
}

pub trait AudioGenerationModel: Clone + Send + Sync {
pub trait AudioGenerationModel: Clone + WasmCompatSend + WasmCompatSync {
type Response: Send + Sync;

fn audio_generation(
Expand All @@ -69,7 +73,7 @@ pub trait AudioGenerationModel: Clone + Send + Sync {
}
}

pub trait AudioGenerationModelDyn: Send + Sync {
pub trait AudioGenerationModelDyn: WasmCompatSend + WasmCompatSync {
fn audio_generation(
&self,
request: AudioGenerationRequest,
Expand Down
4 changes: 2 additions & 2 deletions rig-core/src/completion/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,12 @@ pub enum CompletionError {
#[error("UrlError: {0}")]
UrlError(#[from] url::ParseError),

#[cfg(not(target_family = "wasm"))]
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
/// Error building the completion request
#[error("RequestError: {0}")]
RequestError(#[from] Box<dyn std::error::Error + Send + Sync + 'static>),

#[cfg(target_family = "wasm")]
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
/// Error building the completion request
#[error("RequestError: {0}")]
RequestError(#[from] Box<dyn std::error::Error + 'static>),
Expand Down
4 changes: 2 additions & 2 deletions rig-core/src/embeddings/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ pub enum EmbeddingError {
#[error("UrlError: {0}")]
UrlError(#[from] url::ParseError),

#[cfg(not(target_family = "wasm"))]
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
/// Error processing the document for embedding
#[error("DocumentError: {0}")]
DocumentError(Box<dyn std::error::Error + Send + Sync + 'static>),

#[cfg(target_family = "wasm")]
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
/// Error processing the document for embedding
#[error("DocumentError: {0}")]
DocumentError(Box<dyn std::error::Error + 'static>),
Expand Down
72 changes: 55 additions & 17 deletions rig-core/src/evals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{
completion::CompletionModel,
embeddings::EmbeddingModel,
extractor::{Extractor, ExtractorBuilder},
wasm_compat::{WasmCompatSend, WasmCompatSync},
};

/// Evaluation errors.
Expand All @@ -22,7 +23,7 @@ pub enum EvalError {
Custom(String),
}

/// The outcome of an evaluation (ie, sending an input to an LLM which then gets tested against a set of criteria).
/// The outcome of an evaluation (ie, WasmCompatSending an input to an LLM which then gets tested against a set of criteria).
/// Invalid results due to things like functions returning errors should be encoded as invalid evaluation outcomes.
#[derive(Deserialize, Serialize, Clone, Debug)]
#[serde(tag = "outcome", content = "data")]
Expand Down Expand Up @@ -58,20 +59,20 @@ impl<Output> EvalOutcome<Output> {
/// - Invalid (the output was unable to be retrieved due to an external failure like an API call fail)
pub trait Eval<Output>
where
Output: for<'a> Deserialize<'a> + Serialize + Clone + Send + Sync,
Self: Sized + Send + Sync + 'static,
Output: for<'a> Deserialize<'a> + Serialize + Clone + WasmCompatSend + WasmCompatSync,
Self: Sized + WasmCompatSend + WasmCompatSync + 'static,
{
fn eval(&self, input: String) -> impl Future<Output = EvalOutcome<Output>> + Send;
fn eval(&self, input: String) -> impl Future<Output = EvalOutcome<Output>> + WasmCompatSend;

/// Send a bunch of inputs to be evaluated all in one call.
/// WasmCompatSend a bunch of inputs to be evaluated all in one call.
/// You can set the concurrency limit to help alleviate issues
/// with model provider API limits, as sending requests too quickly may
/// with model provider API limits, as WasmCompatSending requests too quickly may
/// result in throttling or temporary request refusal.
fn eval_batch(
&self,
input: Vec<String>,
concurrency_limit: usize,
) -> impl Future<Output = Vec<EvalOutcome<Output>>> + Send {
) -> impl Future<Output = Vec<EvalOutcome<Output>>> + WasmCompatSend {
use futures::StreamExt;
async move {
let thing: Vec<EvalOutcome<Output>> = futures::stream::iter(input)
Expand Down Expand Up @@ -203,7 +204,12 @@ where
pub struct LlmJudgeMetric<M, T>
where
M: CompletionModel,
T: Judgment + Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a>,
T: Judgment
+ WasmCompatSend
+ WasmCompatSync
+ JsonSchema
+ Serialize
+ for<'a> Deserialize<'a>,
{
ext: Extractor<M, T>,
}
Expand All @@ -213,38 +219,46 @@ where
pub struct LlmJudgeMetricWithFn<M, T>
where
M: CompletionModel,
T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a>,
T: WasmCompatSend + WasmCompatSync + JsonSchema + Serialize + for<'a> Deserialize<'a>,
{
ext: Extractor<M, T>,
#[cfg(not(not(all(feature = "wasm", target_arch = "wasm32"))))]
evaluator: Box<dyn Fn(&T) -> bool + Send + Sync>,

#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
evaluator: Box<dyn Fn(&T) -> bool + Send + Sync>,
}

pub struct LlmJudgeBuilder<M, T>
where
M: CompletionModel,
T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a> + 'static,
T: WasmCompatSend + WasmCompatSync + JsonSchema + Serialize + for<'a> Deserialize<'a> + 'static,
{
ext: ExtractorBuilder<M, T>,
}

pub struct LlmJudgeBuilderWithFn<M, T>
where
M: CompletionModel,
T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a> + 'static,
T: WasmCompatSend + WasmCompatSync + JsonSchema + Serialize + for<'a> Deserialize<'a> + 'static,
{
ext: ExtractorBuilder<M, T>,
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
evaluator: Box<dyn Fn(&T) -> bool + Send + Sync>,
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
evaluator: Box<dyn Fn(&T) -> bool + Send + Sync>,
}

impl<M, T> LlmJudgeBuilder<M, T>
where
M: CompletionModel,
T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a>,
T: WasmCompatSend + WasmCompatSync + JsonSchema + Serialize + for<'a> Deserialize<'a>,
{
pub fn new(ext: ExtractorBuilder<M, T>) -> Self {
Self { ext }
}

#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub fn with_fn<F>(self, f: F) -> LlmJudgeBuilderWithFn<M, T>
where
F: Fn(&T) -> bool + Send + Sync + 'static,
Expand All @@ -255,6 +269,17 @@ where
}
}

#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub fn with_fn<F>(self, f: F) -> LlmJudgeBuilderWithFn<M, T>
where
F: Fn(&T) -> bool + 'static,
{
LlmJudgeBuilderWithFn {
ext: self.ext,
evaluator: Box::new(f),
}
}

pub fn build(self) -> LlmJudgeMetric<M, T>
where
T: Judgment + 'static,
Expand All @@ -272,11 +297,11 @@ where
impl<M, T> LlmJudgeBuilderWithFn<M, T>
where
M: CompletionModel,
T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a> + 'static,
T: WasmCompatSend + WasmCompatSync + JsonSchema + Serialize + for<'a> Deserialize<'a> + 'static,
{
pub fn with_fn<F2>(mut self, f: F2) -> Self
where
F2: Fn(&T) -> bool + Send + Sync + 'static,
F2: Fn(&T) -> bool + WasmCompatSend + WasmCompatSync + 'static,
{
self.evaluator = Box::new(f);
self
Expand Down Expand Up @@ -306,7 +331,14 @@ pub trait Judgment {
impl<M, T> Eval<T> for LlmJudgeMetric<M, T>
where
M: CompletionModel + 'static,
T: Judgment + Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a> + Clone + 'static,
T: Judgment
+ WasmCompatSend
+ WasmCompatSync
+ JsonSchema
+ Serialize
+ for<'a> Deserialize<'a>
+ Clone
+ 'static,
{
async fn eval(&self, input: String) -> EvalOutcome<T> {
match self.ext.extract(input).await {
Expand All @@ -325,7 +357,13 @@ where
impl<M, T> Eval<T> for LlmJudgeMetricWithFn<M, T>
where
M: CompletionModel + 'static,
T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a> + Clone + 'static,
T: WasmCompatSend
+ WasmCompatSync
+ JsonSchema
+ Serialize
+ for<'a> Deserialize<'a>
+ Clone
+ 'static,
{
async fn eval(&self, input: String) -> EvalOutcome<T> {
match self.ext.extract(input).await {
Expand All @@ -344,7 +382,7 @@ where
impl<M, T> From<ExtractorBuilder<M, T>> for LlmJudgeBuilder<M, T>
where
M: CompletionModel,
T: Send + Sync + JsonSchema + Serialize + for<'a> Deserialize<'a>,
T: WasmCompatSend + WasmCompatSync + JsonSchema + Serialize + for<'a> Deserialize<'a>,
{
fn from(ext: ExtractorBuilder<M, T>) -> Self {
Self::new(ext)
Expand Down
Loading