diff --git a/rig-core/examples/lmstudio_streaming.rs b/rig-core/examples/lmstudio_streaming.rs
new file mode 100644
index 000000000..457fcc5b6
--- /dev/null
+++ b/rig-core/examples/lmstudio_streaming.rs
@@ -0,0 +1,29 @@
+use rig::agent::stream_to_stdout;
+use rig::prelude::*;
+use rig::providers::{lmstudio, openai};
+use rig::streaming::StreamingPrompt;
+
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+    // Initialize tracing for debugging
+    tracing_subscriber::fmt().init();
+
+    // Create streaming agent with a single context prompt
+    let agent = lmstudio::Client::from_env()
+        .agent(openai::GPT_4O)
+        .preamble("Be precise and concise.")
+        .temperature(0.5)
+        .build();
+
+    // Stream the response and print chunks as they arrive
+    let mut stream = agent
+        .stream_prompt("When and where and what type is the next solar eclipse?")
+        .await;
+
+    let res = stream_to_stdout(&mut stream).await?;
+
+    println!("Token usage response: {usage:?}", usage = res.usage());
+    println!("Final text response: {message:?}", message = res.response());
+
+    Ok(())
+}
diff --git a/rig-core/src/providers/groq.rs b/rig-core/src/providers/groq.rs
index 282642bd2..8f72764ab 100644
--- a/rig-core/src/providers/groq.rs
+++ b/rig-core/src/providers/groq.rs
@@ -628,7 +628,8 @@ pub async fn send_compatible_streaming_request(
 
     let mut final_usage = Usage {
         prompt_tokens: 0,
-        total_tokens: 0
+        total_tokens: 0,
+        extra: serde_json::Map::new(),
     };
 
     let mut partial_data = None;
diff --git a/rig-core/src/providers/lmstudio.rs b/rig-core/src/providers/lmstudio.rs
new file mode 100644
index 000000000..8f37a12cb
--- /dev/null
+++ b/rig-core/src/providers/lmstudio.rs
@@ -0,0 +1,90 @@
+use crate::prelude::CompletionClient;
+use crate::prelude::EmbeddingsClient;
+use crate::prelude::ProviderClient;
+use crate::prelude::TranscriptionClient;
+use crate::providers::openai;
+pub use openai::completion::*;
+
+const LMSTUDIO_API_BASE_URL: &str = "http://localhost:1234/v1";
+
+/// A client for the LM Studio API.
+#[derive(Clone, Debug)]
+pub struct Client {
+    inner: openai::Client,
+}
+
+impl ProviderClient for Client {
+    fn from_env() -> Self {
+        let base_url = std::env::var("LMSTUDIO_API_BASE")
+            .unwrap_or_else(|_| LMSTUDIO_API_BASE_URL.to_string());
+        let api_key = std::env::var("LMSTUDIO_API_KEY").unwrap_or_else(|_| "lmstudio".to_string());
+
+        let inner = openai::Client::builder(&api_key)
+            .base_url(&base_url)
+            .build()
+            .expect("Failed to build LM Studio client");
+
+        Self { inner }
+    }
+
+    fn from_val(input: crate::client::ProviderValue) -> Self {
+        let crate::client::ProviderValue::Simple(api_key) = input else {
+            panic!("Incorrect provider value type")
+        };
+        let base_url = std::env::var("LMSTUDIO_API_BASE")
+            .unwrap_or_else(|_| LMSTUDIO_API_BASE_URL.to_string());
+
+        let inner = openai::Client::builder(&api_key)
+            .base_url(&base_url)
+            .build()
+            .expect("Failed to build LM Studio client");
+
+        Self { inner }
+    }
+}
+
+impl CompletionClient for Client {
+    type CompletionModel = openai::responses_api::ResponsesCompletionModel;
+
+    fn completion_model(&self, model: &str) -> Self::CompletionModel {
+        self.inner.completion_model(model)
+    }
+}
+
+impl EmbeddingsClient for Client {
+    type EmbeddingModel = openai::embedding::EmbeddingModel;
+
+    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
+        self.inner.embedding_model(model)
+    }
+
+    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
+        self.inner.embedding_model_with_ndims(model, ndims)
+    }
+}
+
+impl TranscriptionClient for Client {
+    type TranscriptionModel = openai::transcription::TranscriptionModel;
+
+    fn transcription_model(&self, model: &str) -> Self::TranscriptionModel {
+        self.inner.transcription_model(model)
+    }
+}
+
+#[cfg(feature = "image")]
+impl crate::client::ImageGenerationClient for Client {
+    type ImageGenerationModel = openai::image_generation::ImageGenerationModel;
+
+    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
+        self.inner.image_generation_model(model)
+    }
+}
+
+#[cfg(feature = "audio")]
+impl crate::client::AudioGenerationClient for Client {
+    type AudioGenerationModel = openai::audio_generation::AudioGenerationModel;
+
+    fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
+        self.inner.audio_generation_model(model)
+    }
+}
diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs
index 53ca700f5..068266cc0 100644
--- a/rig-core/src/providers/mod.rs
+++ b/rig-core/src/providers/mod.rs
@@ -54,6 +54,7 @@ pub mod gemini;
 pub mod groq;
 pub mod huggingface;
 pub mod hyperbolic;
+pub mod lmstudio;
 pub mod mira;
 pub mod mistral;
 pub mod moonshot;
diff --git a/rig-core/src/providers/openai/client.rs b/rig-core/src/providers/openai/client.rs
index 820b4d3cd..b1ee24080 100644
--- a/rig-core/src/providers/openai/client.rs
+++ b/rig-core/src/providers/openai/client.rs
@@ -523,6 +523,7 @@ mod tests {
             audio: None,
             name: None,
             tool_calls: vec![],
+            extra: serde_json::Map::new(),
         };
 
         let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
diff --git a/rig-core/src/providers/openai/completion/mod.rs b/rig-core/src/providers/openai/completion/mod.rs
index c51b02d79..4e5a176e5 100644
--- a/rig-core/src/providers/openai/completion/mod.rs
+++ b/rig-core/src/providers/openai/completion/mod.rs
@@ -129,6 +129,8 @@ pub enum Message {
             skip_serializing_if = "Vec::is_empty"
         )]
         tool_calls: Vec<ToolCall>,
+        #[serde(flatten)]
+        extra: serde_json::Map<String, serde_json::Value>,
     },
     #[serde(rename = "tool")]
     ToolResult {
@@ -382,6 +384,7 @@ impl TryFrom<message::Message> for Vec<Message> {
                     .into_iter()
                     .map(|tool_call| tool_call.into())
                     .collect::<Vec<_>>(),
+                extra: serde_json::Map::new(),
             }])
         }
     }
@@ -556,6 +559,8 @@ pub struct CompletionResponse {
     pub system_fingerprint: Option<String>,
     pub choices: Vec<Choice>,
     pub usage: Option<Usage>,
+    #[serde(flatten)]
+    pub extra: serde_json::Map<String, serde_json::Value>,
 }
 
 impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
@@ -642,6 +647,8 @@ pub struct Choice {
 pub struct Usage {
     pub prompt_tokens: usize,
     pub total_tokens: usize,
+    #[serde(flatten)]
+    pub extra: serde_json::Map<String, serde_json::Value>,
 }
 
 impl fmt::Display for Usage {
@@ -649,6 +656,7 @@ impl fmt::Display for Usage {
         let Usage {
             prompt_tokens,
             total_tokens,
+            ..
         } = self;
         write!(
             f,
diff --git a/rig-core/src/providers/openai/completion/streaming.rs b/rig-core/src/providers/openai/completion/streaming.rs
index 6ebb6559a..174e9c7cc 100644
--- a/rig-core/src/providers/openai/completion/streaming.rs
+++ b/rig-core/src/providers/openai/completion/streaming.rs
@@ -100,7 +100,8 @@ pub async fn send_compatible_streaming_request(
 
     let mut final_usage = Usage {
         prompt_tokens: 0,
-        total_tokens: 0
+        total_tokens: 0,
+        extra: serde_json::Map::new(),
     };
 
     let mut partial_data = None;
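The recurring change in this diff is the `#[serde(flatten)] extra: serde_json::Map<...>` field threaded through `Usage`, `Message`, and `CompletionResponse`. Below is a minimal, self-contained sketch (outside the diff, not rig code) of what that pattern buys: unknown provider fields are captured instead of dropped. It assumes only the `serde` (with `derive`) and `serde_json` crates; the `completion_tokens` key is just an illustrative example of a field a provider might send that the struct does not name.

```rust
use serde::{Deserialize, Serialize};

// Simplified stand-in for the provider `Usage` type touched by this diff.
// `#[serde(flatten)]` collects any JSON fields not named below into `extra`
// instead of discarding them during deserialization.
#[derive(Debug, Serialize, Deserialize)]
struct Usage {
    prompt_tokens: usize,
    total_tokens: usize,
    #[serde(flatten)]
    extra: serde_json::Map<String, serde_json::Value>,
}

fn main() -> Result<(), serde_json::Error> {
    let json = r#"{"prompt_tokens": 12, "total_tokens": 40, "completion_tokens": 28}"#;
    let usage: Usage = serde_json::from_str(json)?;

    // Known fields land in their typed slots...
    assert_eq!(usage.prompt_tokens, 12);
    // ...and the unrecognized field survives in `extra`.
    assert_eq!(usage.extra["completion_tokens"], 28);

    // Re-serializing flattens `extra` back to the top level, round-tripping
    // the provider's original payload.
    println!("{}", serde_json::to_string(&usage)?);
    Ok(())
}
```

This also explains why every struct literal in the diff (e.g. in `send_compatible_streaming_request`) gains an `extra: serde_json::Map::new(),` line: once the field exists, each construction site must initialize it, here to an empty map.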