Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion discord_rig_bot/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2021"

[dependencies]
rig-core = "0.2.1"
rig-core = "0.9"
tokio = { version = "1.34.0", features = ["full"] }
serenity = { version = "0.11", default-features = false, features = ["client", "gateway", "rustls_backend", "cache", "model", "http"] }
dotenv = "0.15.0"
Expand Down
160 changes: 115 additions & 45 deletions discord_rig_bot/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
mod rig_agent;

use anyhow::Result;
use dotenv::dotenv;
use rig_agent::RigAgent;
use serenity::async_trait;
use serenity::model::application::command::Command;
use serenity::model::application::command::CommandOptionType;
use serenity::model::application::interaction::{Interaction, InteractionResponseType};
use serenity::model::gateway::Ready;
use serenity::model::channel::Message;
use serenity::model::gateway::Ready;
use serenity::prelude::*;
use serenity::model::application::command::CommandOptionType;
use std::env;
use std::sync::Arc;
use tracing::{error, info, debug};
use rig_agent::RigAgent;
use dotenv::dotenv;
use tracing::{debug, error, info};

// Define a key for storing the bot's user ID in the TypeMap
struct BotUserId;
Expand All @@ -30,51 +30,93 @@ struct Handler {
#[async_trait]
impl EventHandler for Handler {
async fn interaction_create(&self, ctx: Context, interaction: Interaction) {
debug!("Received an interaction");
debug!("\n\n======> Received an interaction");
if let Interaction::ApplicationCommand(command) = interaction {
debug!("Received command: {}", command.data.name);
let content = match command.data.name.as_str() {
"hello" => "Hello! I'm your helpful Rust and Rig-powered assistant. How can I assist you today?".to_string(),
debug!("\n\n======> Received command: {}", command.data.name);

match command.data.name.as_str() {
"hello" => {
let content = "Hello! I'm your helpful Rust and Rig-powered assistant. How can I assist you today?".to_string();

if let Err(why) = command
.create_interaction_response(&ctx.http, |response| {
response
.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|message| message.content(content))
})
.await
{
error!("Cannot respond to slash command: {}", why);
}
},
"ask" => {
// Step 1: Acknowledge quickly
if let Err(e) = command
.create_interaction_response(&ctx.http, |response| {
response.kind(InteractionResponseType::DeferredChannelMessageWithSource)
})
.await
{
error!("Failed to create deferred response: {:?}", e);
return;
}

let query = command
.data
.options
.get(0)
.and_then(|opt| opt.value.as_ref())
.and_then(|v| v.as_str())
.unwrap_or("What would you like to ask?");
debug!("Query: {}", query);
match self.rig_agent.process_message(query).await {
Ok(response) => response,

debug!("\n\n======> Query: {}", query);

let response = match self.rig_agent.process_string(query).await {
Ok(response) => {
if response.len() > 1900 {
format!("Response truncated due to Discord limits:\n{}", &response[..1897])
} else {
response
}
},
Err(e) => {
error!("Error processing request: {:?}", e);
format!("Error processing request: {:?}", e)
}
};

// Step 3: Edit the original response
if let Err(e) = command
.edit_original_interaction_response(&ctx.http, |message| {
message.content(response)
})
.await
{
error!("Failed to edit interaction response: {:?}", e);
}
},
_ => {
if let Err(why) = command
.create_interaction_response(&ctx.http, |response| {
response
.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|message|
message.content("Not implemented :("))
})
.await
{
error!("Cannot respond to slash command: {}", why);
}
}
_ => "Not implemented :(".to_string(),
};

debug!("Sending response: {}", content);

if let Err(why) = command
.create_interaction_response(&ctx.http, |response| {
response
.kind(InteractionResponseType::ChannelMessageWithSource)
.interaction_response_data(|message| message.content(content))
})
.await
{
error!("Cannot respond to slash command: {}", why);
} else {
debug!("Response sent successfully");
}

debug!("\n\n======> Response sent successfully");
}
}

async fn message(&self, ctx: Context, msg: Message) {
if msg.mentions_me(&ctx.http).await.unwrap_or(false) {
debug!("Bot mentioned in message: {}", msg.content);
debug!("\n\n=====> Bot mentioned in message: {}", msg.content);

let bot_id = {
let data = ctx.data.read().await;
Expand All @@ -85,25 +127,55 @@ impl EventHandler for Handler {
let mention = format!("<@{}>", bot_id);
let content = msg.content.replace(&mention, "").trim().to_string();

debug!("Processed content after removing mention: {}", content);
debug!(
"\n\n=====> Processed content after removing mention: {}",
content
);

match self.rig_agent.process_message(&content).await {
match self.rig_agent.process_message(&ctx, &msg).await {
Ok(response) => {
if let Err(why) = msg.channel_id.say(&ctx.http, response).await {
error!("Error sending message: {:?}", why);
println!("Response sent successfully.");
println!("{}", response);
}
Err(e) => {
println!("Error processing request: {:?}", e);
if let Err(why) = msg.channel_id.say(&ctx.http, format!("Error processing request: {:?}", e)).await {
println!("Error sending error message: {:?}", why);
}
}
}

match self.rig_agent.process_message(&ctx, &msg).await {
Ok(response) => {
println!("Response sent successfully.");
println!("{}", response);
}
Err(e) => {
error!("Error processing message: {:?}", e);
if let Err(why) = msg
.channel_id
.say(&ctx.http, format!("Error processing message: {:?}", e))
.await
{
error!("Error sending error message: {:?}", why);
println!("Error processing request: {:?}", e);
if let Err(why) = msg.channel_id.say(&ctx.http, format!("Error processing request: {:?}", e)).await {
println!("Error sending error message: {:?}", why);
}
}
}


// match self.rig_agent.process_message(&content).await {
// Ok(response) => {
// if let Err(why) = msg.channel_id.say(&ctx.http, response).await {
// error!("Error sending message: {:?}", why);
// }
// }
// Err(e) => {
// error!("Error processing message: {:?}", e);
// if let Err(why) = msg
// .channel_id
// .say(&ctx.http, format!("Error processing message: {:?}", e))
// .await
// {
// error!("Error sending error message: {:?}", why);
// }
// }
// }
} else {
error!("Bot user ID not found in TypeMap");
}
Expand All @@ -121,9 +193,7 @@ impl EventHandler for Handler {
let commands = Command::set_global_application_commands(&ctx.http, |commands| {
commands
.create_application_command(|command| {
command
.name("hello")
.description("Say hello to the bot")
command.name("hello").description("Say hello to the bot")
})
.create_application_command(|command| {
command
Expand Down Expand Up @@ -172,4 +242,4 @@ async fn main() -> Result<()> {
}

Ok(())
}
}
86 changes: 63 additions & 23 deletions discord_rig_bot/src/rig_agent.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
// rig_agent.rs

use anyhow::{Context, Result};
use rig::providers::openai;
use rig::vector_store::in_memory_store::InMemoryVectorStore;
use rig::vector_store::VectorStore;
use rig::embeddings::EmbeddingsBuilder;
use rig::agent::Agent;
use rig::completion::Prompt;
use std::path::Path;
use rig::{
agent::Agent, completion::Prompt, embeddings::EmbeddingsBuilder, providers::openai,
vector_store::in_memory_store::InMemoryVectorStore,
};
use std::fs;
use std::path::Path;
use std::sync::Arc;

use serenity::client::Context as SerenityContext;
use serenity::model::channel::Message;

pub struct RigAgent {
agent: Arc<Agent<openai::CompletionModel>>,
}
Expand All @@ -37,36 +38,40 @@ impl RigAgent {
let md2_content = Self::load_md_content(&md2_path)?;
let md3_content = Self::load_md_content(&md3_path)?;

// Create embeddings and add to vector store
//Create embeddings add to vector store
let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
.simple_document("Rig_guide", &md1_content)
.simple_document("Rig_faq", &md2_content)
.simple_document("Rig_examples", &md3_content)
.document(md1_content)?
.document(md2_content)?
.document(md3_content)?
.build()
.await?;

vector_store.add_documents(embeddings).await?;
vector_store.add_documents(embeddings);

// Create index
let index = vector_store.index(embedding_model);

// Create Agent
let agent = Arc::new(openai_client.agent(openai::GPT_4O)
.preamble("You are an advanced AI assistant powered by Rig, a Rust library for building LLM applications. Your primary function is to provide accurate, helpful, and context-aware responses by leveraging both your general knowledge and specific information retrieved from a curated knowledge base.
let agent = Arc::new(
openai_client
.agent(openai::GPT_4O)
.preamble(
"You are an advanced AI assistant powered by Rig, a Rust library for building LLM applications. Your primary function is to provide accurate, helpful, and context-aware responses by leveraging both your general knowledge and specific information retrieved from a curated knowledge base.

Key responsibilities and behaviors:
1. Information Retrieval: You have access to a vast knowledge base. When answering questions, always consider the context provided by the retrieved information.
2. Clarity and Conciseness: Provide clear and concise answers. Ensure responses are short and concise. Use bullet points or numbered lists for complex information when appropriate.
3. Technical Proficiency: You have deep knowledge about Rig and its capabilities. When discussing Rig or answering related questions, provide detailed and technically accurate information.
4. Code Examples: When appropriate, provide Rust code examples to illustrate concepts, especially when discussing Rig's functionalities. Always format code examples for proper rendering in Discord by wrapping them in triple backticks and specifying the language as 'rust'. For example:
5. Code Examples: When appropriate, provide Rust code examples to illustrate concepts, especially when discussing Rig's functionalities. Always format code examples for proper rendering in Discord by wrapping them in triple backticks and specifying the language as 'rust'. For example:
```rust
let example_code = \"This is how you format Rust code for Discord\";
println!(\"{}\", example_code);
```
5. Keep your responses short and concise. If the user needs more information, they can ask follow-up questions.
")
.dynamic_context(2, index)
.build());
",
)
.dynamic_context(2, index)
.build(),
);

Ok(Self { agent })
}
Expand All @@ -75,8 +80,43 @@ impl RigAgent {
fs::read_to_string(file_path.as_ref())
.with_context(|| format!("Failed to read markdown file: {:?}", file_path.as_ref()))
}

pub async fn process_message(&self, message: &str) -> Result<String> {
self.agent.prompt(message).await.map_err(anyhow::Error::from)

// Add this function for messages that only need a string input/output
pub async fn process_string(&self, message: &str) -> Result<String> {
self.agent
.prompt(message)
.await
.map_err(anyhow::Error::from)
}
}

pub async fn process_message(&self, ctx: &SerenityContext, msg: &Message) -> Result<String> {
// First, create a typing indicator
msg.channel_id.broadcast_typing(&ctx.http).await?;

// Send deferred response to meet 3-second requirement
let mut deferred_msg = msg.channel_id.say(&ctx.http, "Thinking...").await?;

// Use the string content directly, not a reference
let response = self.agent.prompt(msg.content.clone()).await.map_err(anyhow::Error::from)?;

// Truncate if needed
let truncated_response = if response.len() > 1900 {
format!("Response truncated due to Discord limits:\n{}", &response[..1897])
} else {
response
};

// Edit the deferred message
deferred_msg.edit(&ctx.http, |m| m.content(truncated_response.clone())).await?;

Ok(truncated_response)
}

// OLD process_message WITHOUT DEFERRAL AND TRUNCATION
// pub async fn process_message(&self, message: &str) -> Result<String> {
// self.agent
// .prompt(message)
// .await
// .map_err(anyhow::Error::from)
// }
}