Skip to content

Commit 61acb34

Browse files
committed
feat: Add context_tool for mpc servers
1 parent 78b9746 commit 61acb34

File tree

2 files changed

+48
-11
lines changed

2 files changed

+48
-11
lines changed

src/agent/mod.rs

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use event::ConfirmToolResponse;
3434
use futures::StreamExt;
3535
use itertools::Itertools;
3636
use mcp_core::types::ProtocolVersion;
37+
use mcp_core::types::ToolResponseContent;
3738
use rig::agent::AgentBuilder;
3839
use rig::completion::CompletionError;
3940
use rig::completion::CompletionModel;
@@ -201,19 +202,20 @@ impl Agent {
201202
agent_builder
202203
}
203204

204-
async fn add_mcp_tools<M>(
205+
async fn add_mcp_tools<'a, M>(
205206
mut agent_builder: AgentBuilder<M>,
206207
mcp: Option<&McpConfig>,
207-
) -> Result<AgentBuilder<M>>
208+
) -> Result<(AgentBuilder<M>, String)>
208209
where
209210
M: CompletionModel,
210211
{
211212
let Some(mcp_config) = mcp else {
212-
return Ok(agent_builder);
213+
return Ok((agent_builder, String::default()));
213214
};
214215

216+
let mut system_prompt_addons = Vec::default();
215217
for server_config in mcp_config.servers.values() {
216-
match server_config {
218+
match &server_config.transport {
217219
McpClientTransport::Stdio(config) => {
218220
let transport = mcp_core::transport::ClientStdioTransport::new(
219221
&config.command,
@@ -261,16 +263,43 @@ impl Agent {
261263
})?;
262264
let tools_list_res = mcp_client.list_tools(None, None).await?;
263265

266+
if let Some(system_prompt_template) = &server_config.system_prompt {
267+
if let Some(context_tool) = &server_config.context_tool {
268+
let result = mcp_client.call_tool(&context_tool, None).await?;
269+
if result.is_error.is_none_or(|is_error| !is_error) {
270+
let txt = result
271+
.content
272+
.iter()
273+
.filter_map(|content| {
274+
if let ToolResponseContent::Text(txt) = content {
275+
Some(txt.text.clone())
276+
} else {
277+
None
278+
}
279+
})
280+
.join("\\n");
281+
let system_prompt =
282+
system_prompt_template.replace("{CONTEXT_TOOL}", &txt);
283+
system_prompt_addons.push(system_prompt);
284+
}
285+
}
286+
}
264287
agent_builder = tools_list_res
265288
.tools
266289
.into_iter()
290+
.filter(|tool| {
291+
server_config
292+
.context_tool
293+
.as_ref()
294+
.is_none_or(|ctx_tool| ctx_tool != &tool.name)
295+
})
267296
.fold(agent_builder, |builder, tool| {
268297
builder.mcp_tool(tool, mcp_client.clone())
269298
})
270299
}
271300
}
272301
}
273-
Ok(agent_builder)
302+
Ok((agent_builder, system_prompt_addons.join("\\n")))
274303
}
275304

276305
async fn configure_agent<M>(
@@ -281,13 +310,14 @@ impl Agent {
281310
where
282311
M: CompletionModel,
283312
{
284-
agent_builder = agent_builder
285-
.preamble(&context.system_prompt)
286-
.temperature(0.0);
313+
agent_builder = agent_builder.temperature(0.0);
314+
let mut system_prompt = context.system_prompt.clone();
287315
let mcp_config = context.config.mcp.as_ref();
288316
agent_builder = Self::add_static_tools(agent_builder, context);
289-
agent_builder = Self::add_mcp_tools(agent_builder, mcp_config).await?;
290-
let agent = agent_builder.build();
317+
let (agent_builder, system_prompt_addons) =
318+
Self::add_mcp_tools(agent_builder, mcp_config).await?;
319+
system_prompt.push_str(&system_prompt_addons);
320+
let agent = agent_builder.preamble(&system_prompt).build();
291321
*tools_tokens = count_tokens(
292322
&agent
293323
.tools

src/config.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,16 @@ pub enum McpClientTransport {
3939
Sse(McpClientSseTransport),
4040
}
4141

42+
#[derive(Debug, Deserialize, Clone)]
43+
pub struct McpClientConfig {
44+
pub transport: McpClientTransport,
45+
pub context_tool: Option<String>,
46+
/// System prompt to use for the agent with placeholder {CONTEXT_TOOL} will be replaced with the context tool result
47+
pub system_prompt: Option<String>,
48+
}
4249
#[derive(Debug, Deserialize, Clone)]
4350
pub struct McpConfig {
44-
pub servers: HashMap<String, McpClientTransport>,
51+
pub servers: HashMap<String, McpClientConfig>,
4552
}
4653

4754
#[derive(Debug, Deserialize, Clone)]

0 commit comments

Comments
 (0)