-
Notifications
You must be signed in to change notification settings - Fork 479
Expand file tree
/
Copy pathprogress_demo.rs
More file actions
130 lines (117 loc) · 3.8 KB
/
progress_demo.rs
File metadata and controls
130 lines (117 loc) · 3.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
use std::{
io,
pin::Pin,
task::{Context, Poll},
};
use futures::Stream;
use rmcp::{
ErrorData as McpError, RoleServer, ServerHandler, handler::server::tool::ToolRouter, model::*,
service::RequestContext, tool, tool_handler, tool_router,
};
use serde_json::json;
use tokio_stream::StreamExt;
use tracing::debug;
/// In-memory `Stream` source that hands out `data` in chunks of `chunk_size` bytes.
#[derive(Clone)]
struct StreamDataSource {
    /// Bytes to emit.
    data: Vec<u8>,
    /// Bytes per emitted chunk; the final chunk may be shorter.
    chunk_size: usize,
    /// Offset of the next byte to emit; advanced by each `poll_next`.
    position: usize,
}
impl StreamDataSource {
    /// Creates a source over `data` that starts at the first byte.
    ///
    /// `chunk_size` is the number of bytes per emitted chunk; the final chunk may
    /// be shorter.
    pub fn new(data: Vec<u8>, chunk_size: usize) -> Self {
        let position = 0;
        Self {
            position,
            chunk_size,
            data,
        }
    }

    /// Creates a source over the UTF-8 bytes of `text`, one byte per chunk.
    ///
    /// Because chunks are single bytes, any multi-byte UTF-8 character in `text`
    /// is split across several chunks.
    pub fn from_text(text: &str) -> Self {
        let bytes = Vec::from(text.as_bytes());
        Self::new(bytes, 1)
    }
}
impl Stream for StreamDataSource {
    type Item = Result<Vec<u8>, io::Error>;

    /// Yields the next chunk of `data`, or `None` once every byte has been emitted.
    ///
    /// This never returns `Poll::Pending` and never yields `Err`; the `io::Error`
    /// item type only exists to model a fallible source.
    ///
    /// A `chunk_size` of 0 is treated as 1. Without this, `position` would never
    /// advance and the stream would yield empty chunks forever.
    fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        if this.position >= this.data.len() {
            return Poll::Ready(None);
        }
        let start = this.position;
        let step = this.chunk_size.max(1);
        // saturating_add: a huge chunk_size must not overflow before the clamp.
        let end = start.saturating_add(step).min(this.data.len());
        let chunk = this.data[start..end].to_vec();
        this.position = end;
        Poll::Ready(Some(Ok(chunk)))
    }
}
/// Demo MCP server exposing a `stream_processor` tool that walks `data_source`
/// chunk by chunk and reports progress as it goes.
#[derive(Clone)]
pub struct ProgressDemo {
    /// Template data stream; each tool call clones it, so processing always
    /// starts from the first chunk.
    data_source: StreamDataSource,
    /// Router returned by the `#[tool_router]`-generated `Self::tool_router()`;
    /// presumably dispatches tool calls to this type's `#[tool]` methods.
    tool_router: ToolRouter<Self>,
}
#[tool_router]
impl ProgressDemo {
    /// Builds the demo server: registers this impl's tools and uses the text
    /// `"Hello, world!"` as the data source, emitted one byte per chunk.
    // NOTE(review): dead_code is presumably allowed because some binaries that
    // include this module construct the server without calling `new` — confirm.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self {
            tool_router: Self::tool_router(),
            data_source: StreamDataSource::from_text("Hello, world!"),
        }
    }

    /// Streams `data_source` chunk by chunk and reports progress for each chunk.
    ///
    /// Progress notifications carry the progress token the client supplied in the
    /// request metadata (`_meta.progressToken`), which lets the client associate
    /// them with this call. If the client supplied no token, chunks are still
    /// processed but no notifications are sent, as the MCP progress spec requires.
    ///
    /// # Errors
    /// Returns an internal error if a chunk cannot be read from the source or a
    /// progress notification cannot be delivered.
    #[tool(description = "Process data stream with progress updates")]
    async fn stream_processor(
        &self,
        ctx: RequestContext<RoleServer>,
    ) -> Result<CallToolResult, McpError> {
        let progress_token = ctx.meta.get_progress_token();
        let mut counter: u32 = 0;
        // Work on a copy so every call starts from the beginning of the data.
        let mut data_source = self.data_source.clone();
        while let Some(chunk) = data_source.next().await {
            // Surface source errors to the client instead of panicking.
            let chunk = chunk.map_err(|e| {
                McpError::internal_error(
                    format!("Failed to read data chunk: {}", e),
                    Some(json!({
                        "progress": counter,
                        "error": e.to_string()
                    })),
                )
            })?;
            let chunk_str = String::from_utf8_lossy(&chunk);
            counter += 1;
            if let Some(token) = &progress_token {
                let progress_param = ProgressNotificationParam {
                    progress_token: token.clone(),
                    progress: f64::from(counter),
                    total: None,
                    message: Some(chunk_str.to_string()),
                };
                if let Err(e) = ctx.peer.notify_progress(progress_param).await {
                    return Err(McpError::internal_error(
                        format!("Failed to notify progress: {}", e),
                        Some(json!({
                            "record": chunk_str,
                            "progress": counter,
                            "error": e.to_string()
                        })),
                    ));
                }
            }
            debug!("Processed record: {}", chunk_str);
        }
        Ok(CallToolResult::success(vec![Content::text(format!(
            "Processed {} records successfully",
            counter
        ))]))
    }
}
#[tool_handler]
impl ServerHandler for ProgressDemo {
    /// Describes this server to connecting clients.
    ///
    /// Only the tools capability is enabled, and the protocol version is pinned to
    /// 2024-11-05. The server identity comes from `Implementation::from_build_env()`,
    /// and the instructions explain the purpose of the demo.
    fn get_info(&self) -> ServerInfo {
        let capabilities = ServerCapabilities::builder().enable_tools().build();
        let instructions = String::from(
            "This server demonstrates progress notifications during long-running operations. \
             Use the tools to see real-time progress updates for batch processing",
        );
        let base = ServerInfo::new(capabilities);
        base.with_protocol_version(ProtocolVersion::V_2024_11_05)
            .with_server_info(Implementation::from_build_env())
            .with_instructions(instructions)
    }
}