29 changes: 29 additions & 0 deletions rig-core/examples/lmstudio_streaming.rs
@@ -0,0 +1,29 @@
use rig::agent::stream_to_stdout;
use rig::prelude::*;
use rig::providers::{lmstudio, openai};
use rig::streaming::StreamingPrompt;

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Enable tracing for debugging
    tracing_subscriber::fmt().init();

    // Create a streaming agent with a single context prompt
    let agent = lmstudio::Client::from_env()
        .agent(openai::GPT_4O)
        .preamble("Be precise and concise.")
        .temperature(0.5)
        .build();

    // Stream the response and print chunks as they arrive
    let mut stream = agent
        .stream_prompt("When and where is the next solar eclipse, and what type will it be?")
        .await;

    let res = stream_to_stdout(&mut stream).await?;

    println!("Token usage response: {usage:?}", usage = res.usage());
    println!("Final text response: {message:?}", message = res.response());

    Ok(())
}
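A note for anyone testing this example: `from_env` (defined in the new provider module below) falls back to a built-in localhost base URL and the placeholder key "lmstudio", so a default local LM Studio server needs no environment setup. For a non-default host or port, the endpoint can also be set directly on the underlying OpenAI-compatible client instead of via `LMSTUDIO_API_BASE` — a minimal sketch, where the URL and model id are placeholder assumptions:

use rig::prelude::*;
use rig::providers::openai;

fn main() {
    // Sketch: build the OpenAI-compatible client against an explicit LM Studio
    // endpoint, mirroring what lmstudio::Client::from_env does internally.
    // The URL and model id below are illustrative, not defaults.
    let _agent = openai::Client::builder("lmstudio")
        .base_url("http://192.168.1.50:1234/v1")
        .build()
        .expect("failed to build client")
        .agent("qwen2.5-7b-instruct")
        .preamble("Be precise and concise.")
        .build();
}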
3 changes: 2 additions & 1 deletion rig-core/src/providers/groq.rs
@@ -628,7 +628,8 @@ pub async fn send_compatible_streaming_request(

     let mut final_usage = Usage {
         prompt_tokens: 0,
-        total_tokens: 0
+        total_tokens: 0,
+        extra: serde_json::Map::new(),
     };

     let mut partial_data = None;
90 changes: 90 additions & 0 deletions rig-core/src/providers/lmstudio.rs
@@ -0,0 +1,90 @@
use crate::prelude::CompletionClient;
use crate::prelude::EmbeddingsClient;
use crate::prelude::ProviderClient;
use crate::prelude::TranscriptionClient;
use crate::providers::openai;
pub use openai::completion::*;

// LM Studio's local server listens on port 1234 by default.
const LMSTUDIO_API_BASE_URL: &str = "http://localhost:1234/v1";

/// A client for the LM Studio API.
#[derive(Clone, Debug)]
pub struct Client {
    inner: openai::Client,
}

impl ProviderClient for Client {
    fn from_env() -> Self {
        let base_url = std::env::var("LMSTUDIO_API_BASE")
            .unwrap_or_else(|_| LMSTUDIO_API_BASE_URL.to_string());
        let api_key = std::env::var("LMSTUDIO_API_KEY").unwrap_or_else(|_| "lmstudio".to_string());

        let inner = openai::Client::builder(&api_key)
            .base_url(&base_url)
            .build()
            .expect("Failed to build LM Studio client");

        Self { inner }
    }

    fn from_val(input: crate::client::ProviderValue) -> Self {
        let crate::client::ProviderValue::Simple(api_key) = input else {
            panic!("Incorrect provider value type")
        };
        let base_url = std::env::var("LMSTUDIO_API_BASE")
            .unwrap_or_else(|_| LMSTUDIO_API_BASE_URL.to_string());

        let inner = openai::Client::builder(&api_key)
            .base_url(&base_url)
            .build()
            .expect("Failed to build LM Studio client");

        Self { inner }
    }
}

impl CompletionClient for Client {
    type CompletionModel = openai::responses_api::ResponsesCompletionModel;

    fn completion_model(&self, model: &str) -> Self::CompletionModel {
        self.inner.completion_model(model)
    }
}

impl EmbeddingsClient for Client {
    type EmbeddingModel = openai::embedding::EmbeddingModel;

    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
        self.inner.embedding_model(model)
    }

    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
        self.inner.embedding_model_with_ndims(model, ndims)
    }
}

impl TranscriptionClient for Client {
    type TranscriptionModel = openai::transcription::TranscriptionModel;

    fn transcription_model(&self, model: &str) -> Self::TranscriptionModel {
        self.inner.transcription_model(model)
    }
}

#[cfg(feature = "image")]
impl crate::client::ImageGenerationClient for Client {
    type ImageGenerationModel = openai::image_generation::ImageGenerationModel;

    fn image_generation_model(&self, model: &str) -> Self::ImageGenerationModel {
        self.inner.image_generation_model(model)
    }
}

#[cfg(feature = "audio")]
impl crate::client::AudioGenerationClient for Client {
    type AudioGenerationModel = openai::audio_generation::AudioGenerationModel;

    fn audio_generation_model(&self, model: &str) -> Self::AudioGenerationModel {
        self.inner.audio_generation_model(model)
    }
}
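Since every trait here forwards to the wrapped openai::Client, the provider behaves like the other OpenAI-compatible integrations. A short sketch of how the impls above are meant to be consumed — the model identifiers are placeholder assumptions, since LM Studio exposes whatever models the local server has loaded:

use rig::prelude::*;
use rig::providers::lmstudio;

fn main() {
    let client = lmstudio::Client::from_env();

    // Completions route through the ResponsesCompletionModel alias above.
    let _chat = client.completion_model("qwen2.5-7b-instruct");

    // Embeddings can pin an explicit dimension count via the ndims variant.
    let _embedder =
        client.embedding_model_with_ndims("text-embedding-nomic-embed-text-v1.5", 768);
}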
1 change: 1 addition & 0 deletions rig-core/src/providers/mod.rs
@@ -54,6 +54,7 @@ pub mod gemini;
 pub mod groq;
 pub mod huggingface;
 pub mod hyperbolic;
+pub mod lmstudio;
 pub mod mira;
 pub mod mistral;
 pub mod moonshot;
1 change: 1 addition & 0 deletions rig-core/src/providers/openai/client.rs
@@ -523,6 +523,7 @@ mod tests {
             audio: None,
             name: None,
             tool_calls: vec![],
+            extra: serde_json::Map::new(),
         };

         let converted_user_message: message::Message = user_message.clone().try_into().unwrap();
8 changes: 8 additions & 0 deletions rig-core/src/providers/openai/completion/mod.rs
@@ -129,6 +129,8 @@ pub enum Message {
             skip_serializing_if = "Vec::is_empty"
         )]
         tool_calls: Vec<ToolCall>,
+        #[serde(flatten)]
+        extra: serde_json::Map<String, serde_json::Value>,
     },
     #[serde(rename = "tool")]
     ToolResult {
@@ -382,6 +384,7 @@ impl TryFrom<message::Message> for Vec<Message> {
                     .into_iter()
                     .map(|tool_call| tool_call.into())
                     .collect::<Vec<_>>(),
+                extra: serde_json::Map::new(),
             }])
         }
     }
@@ -556,6 +559,8 @@ pub struct CompletionResponse {
     pub system_fingerprint: Option<String>,
     pub choices: Vec<Choice>,
     pub usage: Option<Usage>,
+    #[serde(flatten)]
+    pub extra: serde_json::Map<String, serde_json::Value>,
 }

 impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
@@ -642,13 +647,16 @@ pub struct Choice {
 pub struct Usage {
     pub prompt_tokens: usize,
     pub total_tokens: usize,
+    #[serde(flatten)]
+    pub extra: serde_json::Map<String, serde_json::Value>,
 }

 impl fmt::Display for Usage {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         let Usage {
             prompt_tokens,
             total_tokens,
+            ..
         } = self;
         write!(
             f,
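The `#[serde(flatten)] extra` field threaded through this diff is the substantive change: OpenAI-compatible servers (LM Studio included) can return fields the typed structs don't model, and without a catch-all map serde simply drops them on deserialization. A self-contained sketch of the pattern, using an illustrative struct rather than rig's own types:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Usage {
    prompt_tokens: usize,
    total_tokens: usize,
    // Any JSON keys not named above are collected here instead of discarded.
    #[serde(flatten)]
    extra: serde_json::Map<String, serde_json::Value>,
}

fn main() {
    let json = r#"{"prompt_tokens": 10, "total_tokens": 25, "completion_tokens": 15}"#;
    let usage: Usage = serde_json::from_str(json).unwrap();
    assert_eq!(usage.prompt_tokens, 10);
    // "completion_tokens" is not modeled, but survives in `extra`.
    assert_eq!(usage.extra["completion_tokens"], 15);
}

This also explains the `..` added to the `Usage` destructuring in the `Display` impl above: the pattern must now ignore the new field it doesn't print.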
3 changes: 2 additions & 1 deletion rig-core/src/providers/openai/completion/streaming.rs
@@ -100,7 +100,8 @@ pub async fn send_compatible_streaming_request(

     let mut final_usage = Usage {
         prompt_tokens: 0,
-        total_tokens: 0
+        total_tokens: 0,
+        extra: serde_json::Map::new(),
     };

     let mut partial_data = None;