diff --git a/README.md b/README.md index 507b55e..1d334d6 100644 --- a/README.md +++ b/README.md @@ -180,7 +180,7 @@ pub struct MyServerHandler; #[async_trait] impl ServerHandler for MyServerHandler { // Handle ListToolsRequest, return list of available tools as ListToolsResult - async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: &dyn McpServer) -> Result { + async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: Arc) -> Result { Ok(ListToolsResult { tools: vec![SayHelloTool::tool()], @@ -191,7 +191,7 @@ impl ServerHandler for MyServerHandler { } /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: &dyn McpServer, ) -> Result { + async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc, ) -> Result { if request.tool_name() == SayHelloTool::tool_name() { Ok( CallToolResult::text_content( vec![TextContent::from("Hello World!".to_string())] )) diff --git a/crates/rust-mcp-sdk/README.md b/crates/rust-mcp-sdk/README.md index 1581d1d..cbe7318 100644 --- a/crates/rust-mcp-sdk/README.md +++ b/crates/rust-mcp-sdk/README.md @@ -180,7 +180,7 @@ pub struct MyServerHandler; #[async_trait] impl ServerHandler for MyServerHandler { // Handle ListToolsRequest, return list of available tools as ListToolsResult - async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: &dyn McpServer) -> Result { + async fn handle_list_tools_request(&self, request: ListToolsRequest, runtime: Arc) -> Result { Ok(ListToolsResult { tools: vec![SayHelloTool::tool()], @@ -191,7 +191,7 @@ impl ServerHandler for MyServerHandler { } /// Handles requests to call a specific tool. - async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: &dyn McpServer, ) -> Result { + async fn handle_call_tool_request( &self, request: CallToolRequest, runtime: Arc ) -> Result { if request.tool_name() == SayHelloTool::tool_name() { Ok( CallToolResult::text_content( vec![TextContent::from("Hello World!".to_string())] )) diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs index 0a77913..daf5d94 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/hyper_utils.rs @@ -166,11 +166,11 @@ pub async fn start_new_session( let h: Arc = state.handler.clone(); // create a new server instance with unique session_id and - let runtime: Arc = Arc::new(server_runtime::create_server_instance( + let runtime: Arc = server_runtime::create_server_instance( Arc::clone(&state.server_details), h, session_id.to_owned(), - )); + ); tracing::info!("a new client joined : {}", &session_id); diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs index e1c00f8..a014e94 100644 --- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs +++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs @@ -99,11 +99,11 @@ pub async fn handle_sse( .unwrap(); let h: Arc = state.handler.clone(); // create a new server instance with unique session_id and - let server: Arc = Arc::new(server_runtime::create_server_instance( + let server: Arc = server_runtime::create_server_instance( Arc::clone(&state.server_details), h, session_id.to_owned(), - )); + ); state .session_store diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs index 89aebf5..9b9577e 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs @@ -1,6 +1,7 @@ use crate::schema::{schema_utils::CallToolError, *}; use async_trait::async_trait; use serde_json::Value; +use std::sync::Arc; use crate::{mcp_traits::mcp_server::McpServer, utils::enforce_compatible_protocol_version}; @@ -15,7 +16,7 @@ pub trait ServerHandler: Send + Sync + 'static { /// The `runtime` parameter provides access to the server's runtime environment, allowing /// interaction with the server's capabilities. /// The default implementation does nothing. - async fn on_initialized(&self, runtime: &dyn McpServer) {} + async fn on_initialized(&self, runtime: Arc) {} /// Handles the InitializeRequest from a client. /// @@ -29,7 +30,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_initialize_request( &self, initialize_request: InitializeRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { let mut server_info = runtime.server_info().to_owned(); // Provide compatibility for clients using older MCP protocol versions. @@ -65,7 +66,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_ping_request( &self, _: PingRequest, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result { Ok(Result::default()) } @@ -77,7 +78,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_resources_request( &self, request: ListResourcesRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -93,7 +94,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_resource_templates_request( &self, request: ListResourceTemplatesRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -109,7 +110,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_read_resource_request( &self, request: ReadResourceRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -125,7 +126,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_subscribe_request( &self, request: SubscribeRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -141,7 +142,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_unsubscribe_request( &self, request: UnsubscribeRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -157,7 +158,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_prompts_request( &self, request: ListPromptsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -173,7 +174,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_get_prompt_request( &self, request: GetPromptRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -189,7 +190,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -205,7 +206,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime .assert_server_request_capabilities(request.method()) @@ -220,7 +221,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_set_level_request( &self, request: SetLevelRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -236,7 +237,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_complete_request( &self, request: CompleteRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; Err(RpcError::method_not_found().with_message(format!( @@ -252,7 +253,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_custom_request( &self, request: Value, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { Err(RpcError::method_not_found() .with_message("No handler is implemented for custom requests.".to_string())) @@ -265,7 +266,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_initialized_notification( &self, notification: InitializedNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -275,7 +276,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_cancelled_notification( &self, notification: CancelledNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -285,7 +286,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_progress_notification( &self, notification: ProgressNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -295,7 +296,7 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_roots_list_changed_notification( &self, notification: RootsListChangedNotification, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -320,18 +321,8 @@ pub trait ServerHandler: Send + Sync + 'static { async fn handle_error( &self, error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } - - /// Called when the server has successfully started. - /// - /// Sends a "Server started successfully" message to stderr. - /// Customize this function in your specific handler to implement behavior tailored to your MCP server's capabilities and requirements. - async fn on_server_started(&self, runtime: &dyn McpServer) { - let _ = runtime - .stderr_message("Server started successfully".into()) - .await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs index e7b0e6d..9275da7 100644 --- a/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler_core.rs @@ -1,8 +1,8 @@ +use crate::mcp_traits::mcp_server::McpServer; use crate::schema::schema_utils::*; use crate::schema::*; use async_trait::async_trait; - -use crate::mcp_traits::mcp_server::McpServer; +use std::sync::Arc; /// Defines the `ServerHandlerCore` trait for handling Model Context Protocol (MCP) server operations. /// Unlike `ServerHandler`, this trait offers no default implementations, providing full control over MCP message handling @@ -14,7 +14,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { /// The `runtime` parameter provides access to the server's runtime environment, allowing /// interaction with the server's capabilities. /// The default implementation does nothing. - async fn on_initialized(&self, _runtime: &dyn McpServer) {} + async fn on_initialized(&self, _runtime: Arc) {} /// Asynchronously handles an incoming request from the client. /// @@ -26,7 +26,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { async fn handle_request( &self, request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result; /// Asynchronously handles an incoming notification from the client. @@ -36,7 +36,7 @@ pub trait ServerHandlerCore: Send + Sync + 'static { async fn handle_notification( &self, notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError>; /// Asynchronously handles an error received from the client. @@ -46,11 +46,6 @@ pub trait ServerHandlerCore: Send + Sync + 'static { async fn handle_error( &self, error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result<(), RpcError>; - async fn on_server_started(&self, runtime: &dyn McpServer) { - let _ = runtime - .stderr_message("Server started successfully".into()) - .await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs index 44f3e53..57ba260 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs @@ -22,9 +22,10 @@ use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use tokio::io::AsyncWriteExt; -use tokio::sync::{oneshot, watch}; +use tokio::sync::{mpsc, oneshot, watch}; pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM"; +const TASK_CHANNEL_CAPACITY: usize = 500; // Define a type alias for the TransportDispatcher trait object type TransportType = Arc< @@ -55,8 +56,6 @@ pub struct ServerRuntime { impl McpServer for ServerRuntime { /// Set the client details, storing them in client_details async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> { - self.handler.on_server_started(self).await; - self.client_details_tx .send(Some(client_details)) .map_err(|_| { @@ -132,8 +131,9 @@ impl McpServer for ServerRuntime { } /// Main runtime loop, processes incoming messages and handles requests - async fn start(&self) -> SdkResult<()> { - let transport_map = self.transport_map.read().await; + async fn start(self: Arc) -> SdkResult<()> { + let self_clone = self.clone(); + let transport_map = self_clone.transport_map.read().await; let transport = transport_map.get(DEFAULT_STREAM_ID).ok_or( RpcError::internal_error() @@ -142,43 +142,88 @@ impl McpServer for ServerRuntime { let mut stream = transport.start().await?; + // Create a channel to collect results from spawned tasks + let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY); + // Process incoming messages from the client while let Some(mcp_messages) = stream.next().await { match mcp_messages { ClientMessages::Single(client_message) => { - let result = self.handle_message(client_message, transport).await; - - match result { - Ok(result) => { - if let Some(result) = result { - transport - .send_message(ServerMessages::Single(result), None) - .await?; + let transport = transport.clone(); + let self = self.clone(); + let tx = tx.clone(); + + // Handle incoming messages in a separate task to avoid blocking the stream. + tokio::spawn(async move { + let result = self.handle_message(client_message, &transport).await; + + let send_result: SdkResult<_> = match result { + Ok(result) => { + if let Some(result) = result { + transport + .send_message(ServerMessages::Single(result), None) + .map_err(|e| e.into()) + .await + } else { + Ok(None) + } } + Err(error) => { + tracing::error!("Error handling message : {}", error); + Ok(None) + } + }; + // Send result to the main loop + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send result to channel: {}", error); } - Err(error) => { - tracing::error!("Error handling message : {}", error) - } - } + }); } ClientMessages::Batch(client_messages) => { - let handling_tasks: Vec<_> = client_messages - .into_iter() - .map(|client_message| self.handle_message(client_message, transport)) - .collect(); - - let results: Vec<_> = try_join_all(handling_tasks).await?; - - let results: Vec<_> = results.into_iter().flatten().collect(); + let transport = transport.clone(); + let self = self_clone.clone(); + let tx = tx.clone(); + + tokio::spawn(async move { + let handling_tasks: Vec<_> = client_messages + .into_iter() + .map(|client_message| self.handle_message(client_message, &transport)) + .collect(); + + let send_result = match try_join_all(handling_tasks).await { + Ok(results) => { + let results: Vec<_> = results.into_iter().flatten().collect(); + if !results.is_empty() { + transport + .send_message(ServerMessages::Batch(results), None) + .map_err(|e| e.into()) + .await + } else { + Ok(None) + } + } + Err(error) => Err(error), + }; - if !results.is_empty() { - transport - .send_message(ServerMessages::Batch(results), None) - .await?; - } + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send batch result to channel: {}", error); + } + }); } } + + // Check for results from spawned tasks to propagate errors + while let Ok(result) = rx.try_recv() { + result?; // Propagate errors + } } + + // Drop tx to close the channel and collect remaining results + drop(tx); + while let Some(result) = rx.recv().await { + result?; // Propagate errors + } + return Ok(()); } @@ -223,7 +268,7 @@ impl ServerRuntime { } pub(crate) async fn handle_message( - &self, + self: &Arc, message: ClientMessage, transport: &Arc< dyn TransportDispatcher< @@ -240,7 +285,7 @@ impl ServerRuntime { ClientMessage::Request(client_jsonrpc_request) => { let result = self .handler - .handle_request(client_jsonrpc_request.request, self) + .handle_request(client_jsonrpc_request.request, self.clone()) .await; // create a response to send back to the client let response: MessageFromServer = match result { @@ -262,13 +307,13 @@ impl ServerRuntime { } ClientMessage::Notification(client_jsonrpc_notification) => { self.handler - .handle_notification(client_jsonrpc_notification.notification, self) + .handle_notification(client_jsonrpc_notification.notification, self.clone()) .await?; None } ClientMessage::Error(jsonrpc_error) => { self.handler - .handle_error(&jsonrpc_error.error, self) + .handle_error(&jsonrpc_error.error, self.clone()) .await?; if let Some(tx_response) = transport.pending_request_tx(&jsonrpc_error.id).await { tx_response @@ -282,7 +327,6 @@ impl ServerRuntime { } None } - // The response is the result of a request, it is processed at the transport level. ClientMessage::Response(response) => { if let Some(tx_response) = transport.pending_request_tx(&response.id).await { tx_response @@ -379,7 +423,8 @@ impl ServerRuntime { self.store_transport(stream_id, Arc::new(transport)).await?; - let transport = self.transport_by_stream(stream_id).await?; + let self_clone = self.clone(); + let transport = self_clone.transport_by_stream(stream_id).await?; let (disconnect_tx, mut disconnect_rx) = oneshot::channel::<()>(); let abort_alive_task = transport @@ -397,40 +442,96 @@ impl ServerRuntime { transport.consume_string_payload(&payload).await?; } + // Create a channel to collect results from spawned tasks + let (tx, mut rx) = mpsc::channel(TASK_CHANNEL_CAPACITY); + loop { tokio::select! { Some(mcp_messages) = stream.next() =>{ match mcp_messages { ClientMessages::Single(client_message) => { - let result = self.handle_message(client_message, &transport).await?; - if let Some(result) = result { - transport.send_message(ServerMessages::Single(result), None).await?; - } + let transport = transport.clone(); + let self_clone = self.clone(); + let tx = tx.clone(); + tokio::spawn(async move { + + let result = self_clone.handle_message(client_message, &transport).await; + + let send_result: SdkResult<_> = match result { + Ok(result) => { + if let Some(result) = result { + transport + .send_message(ServerMessages::Single(result), None) + .map_err(|e| e.into()) + .await + } else { + Ok(None) + } + } + Err(error) => { + tracing::error!("Error handling message : {}", error); + Ok(None) + } + }; + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send batch result to channel: {}", error); + } + }); } ClientMessages::Batch(client_messages) => { - let handling_tasks: Vec<_> = client_messages - .into_iter() - .map(|client_message| self.handle_message(client_message, &transport)) - .collect(); - - let results: Vec<_> = try_join_all(handling_tasks).await?; - - let results: Vec<_> = results.into_iter().flatten().collect(); - - - if !results.is_empty() { - transport.send_message(ServerMessages::Batch(results), None).await?; - } + let transport = transport.clone(); + let self_clone = self_clone.clone(); + let tx = tx.clone(); + + tokio::spawn(async move { + let handling_tasks: Vec<_> = client_messages + .into_iter() + .map(|client_message| self_clone.handle_message(client_message, &transport)) + .collect(); + + let send_result = match try_join_all(handling_tasks).await { + Ok(results) => { + let results: Vec<_> = results.into_iter().flatten().collect(); + if !results.is_empty() { + transport.send_message(ServerMessages::Batch(results), None) + .map_err(|e| e.into()) + .await + }else { + Ok(None) + } + }, + Err(error) => Err(error), + }; + if let Err(error) = tx.send(send_result).await { + tracing::error!("Failed to send batch result to channel: {}", error); + } + }); } } + + // Check for results from spawned tasks to propagate errors + while let Ok(result) = rx.try_recv() { + result?; // Propagate errors + } + // close the stream after all messages are sent, unless it is a standalone stream if !stream_id.eq(DEFAULT_STREAM_ID){ + // Drop tx to close the channel and collect remaining results + drop(tx); + while let Some(result) = rx.recv().await { + result?; // Propagate errors + } return Ok(()); } } _ = &mut disconnect_rx => { + // Drop tx to close the channel and collect remaining results + drop(tx); + while let Some(result) = rx.recv().await { + result?; // Propagate errors + } self.remove_transport(stream_id).await?; // Disconnection detected by keep-alive task return Err(SdkError::connection_closed().into()); @@ -445,10 +546,10 @@ impl ServerRuntime { server_details: Arc, handler: Arc, session_id: SessionId, - ) -> Self { + ) -> Arc { let (client_details_tx, client_details_rx) = watch::channel::>(None); - Self { + Arc::new(Self { server_details, handler, session_id: Some(session_id), @@ -456,7 +557,7 @@ impl ServerRuntime { client_details_tx, client_details_rx, request_id_gen: Box::new(RequestIdGenNumeric::new(None)), - } + }) } pub(crate) fn new( @@ -469,12 +570,12 @@ impl ServerRuntime { ServerMessage, >, handler: Arc, - ) -> Self { + ) -> Arc { let mut map: HashMap = HashMap::new(); map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport)); let (client_details_tx, client_details_rx) = watch::channel::>(None); - Self { + Arc::new(Self { server_details: Arc::new(server_details), handler, #[cfg(feature = "hyper-server")] @@ -483,6 +584,6 @@ impl ServerRuntime { client_details_tx, client_details_rx, request_id_gen: Box::new(RequestIdGenNumeric::new(None)), - } + }) } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs index ea19e19..5fbc43c 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs @@ -49,7 +49,7 @@ pub fn create_server( ServerMessage, >, handler: impl ServerHandler, -) -> ServerRuntime { +) -> Arc { ServerRuntime::new( server_details, transport, @@ -62,7 +62,7 @@ pub(crate) fn create_server_instance( server_details: Arc, handler: Arc, session_id: SessionId, -) -> ServerRuntime { +) -> Arc { ServerRuntime::new_instance(server_details, handler, session_id) } @@ -80,7 +80,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_request( &self, client_jsonrpc_request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { match client_jsonrpc_request { schema_utils::RequestFromClient::ClientRequest(client_request) => { @@ -178,7 +178,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_error( &self, jsonrpc_error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; Ok(()) @@ -187,7 +187,7 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { match client_jsonrpc_notification { schema_utils::NotificationFromClient::ClientNotification(client_notification) => { @@ -199,7 +199,10 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { } ClientNotification::InitializedNotification(initialized_notification) => { self.handler - .handle_initialized_notification(initialized_notification, runtime) + .handle_initialized_notification( + initialized_notification, + runtime.clone(), + ) .await?; self.handler.on_initialized(runtime).await; } @@ -226,8 +229,4 @@ impl McpServerHandler for ServerRuntimeInternalHandler> { } Ok(()) } - - async fn on_server_started(&self, runtime: &dyn McpServer) { - self.handler.on_server_started(runtime).await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs index e0e7108..5ed2239 100644 --- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs +++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs @@ -43,7 +43,7 @@ pub fn create_server( ServerMessage, >, handler: impl ServerHandlerCore, -) -> ServerRuntime { +) -> Arc { ServerRuntime::new( server_details, transport, @@ -66,7 +66,7 @@ impl McpServerHandler for RuntimeCoreInternalHandler> async fn handle_request( &self, client_jsonrpc_request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { // store the client details if the request is a client initialization request if let schema_utils::RequestFromClient::ClientRequest(ClientRequest::InitializeRequest( @@ -88,7 +88,7 @@ impl McpServerHandler for RuntimeCoreInternalHandler> async fn handle_error( &self, jsonrpc_error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { self.handler.handle_error(jsonrpc_error, runtime).await?; Ok(()) @@ -96,11 +96,11 @@ impl McpServerHandler for RuntimeCoreInternalHandler> async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()> { // Trigger the `on_initialized()` callback if an `initialized_notification` is received from the client. if client_jsonrpc_notification.is_initialized_notification() { - self.handler.on_initialized(runtime).await; + self.handler.on_initialized(runtime.clone()).await; } // handle notification @@ -109,7 +109,4 @@ impl McpServerHandler for RuntimeCoreInternalHandler> .await?; Ok(()) } - async fn on_server_started(&self, runtime: &dyn McpServer) { - self.handler.on_server_started(runtime).await; - } } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs index 2974bfc..cb37f2a 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_handler.rs @@ -6,9 +6,9 @@ use crate::schema::schema_utils::{NotificationFromClient, RequestFromClient, Res #[cfg(feature = "client")] use crate::schema::schema_utils::{NotificationFromServer, RequestFromServer, ResultFromClient}; -use crate::schema::RpcError; - use crate::error::SdkResult; +use crate::schema::RpcError; +use std::sync::Arc; #[cfg(feature = "client")] use super::mcp_client::McpClient; @@ -18,21 +18,20 @@ use super::mcp_server::McpServer; #[cfg(feature = "server")] #[async_trait] pub trait McpServerHandler: Send + Sync { - async fn on_server_started(&self, runtime: &dyn McpServer); async fn handle_request( &self, client_jsonrpc_request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result; async fn handle_error( &self, jsonrpc_error: &RpcError, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()>; async fn handle_notification( &self, client_jsonrpc_notification: NotificationFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> SdkResult<()>; } diff --git a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs index 2eab9db..dc860b6 100644 --- a/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs +++ b/crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs @@ -13,16 +13,15 @@ use crate::schema::{ ResourceUpdatedNotification, ResourceUpdatedNotificationParams, RpcError, ServerCapabilities, SetLevelRequest, ToolListChangedNotification, ToolListChangedNotificationParams, }; +use crate::{error::SdkResult, utils::format_assertion_message}; use async_trait::async_trait; use rust_mcp_transport::SessionId; -use std::time::Duration; - -use crate::{error::SdkResult, utils::format_assertion_message}; +use std::{sync::Arc, time::Duration}; //TODO: support options , such as enforceStrictCapabilities #[async_trait] pub trait McpServer: Sync + Send { - async fn start(&self) -> SdkResult<()>; + async fn start(self: Arc) -> SdkResult<()>; async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>; fn server_info(&self) -> &InitializeResult; fn client_info(&self) -> Option; diff --git a/crates/rust-mcp-sdk/tests/common/test_server.rs b/crates/rust-mcp-sdk/tests/common/test_server.rs index aa8e2fb..176e0d2 100644 --- a/crates/rust-mcp-sdk/tests/common/test_server.rs +++ b/crates/rust-mcp-sdk/tests/common/test_server.rs @@ -17,7 +17,7 @@ pub mod test_server_common { mcp_server::{hyper_server, HyperServer, HyperServerOptions, IdGenerator, ServerHandler}, McpServer, SessionId, }; - use std::sync::RwLock; + use std::sync::{Arc, RwLock}; use std::time::Duration; use tokio::time::timeout; @@ -71,16 +71,10 @@ pub mod test_server_common { #[async_trait] impl ServerHandler for TestServerHandler { - async fn on_server_started(&self, runtime: &dyn McpServer) { - let _ = runtime - .stderr_message("Server started successfully".into()) - .await; - } - async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime.assert_server_request_capabilities(request.method())?; @@ -94,7 +88,7 @@ pub mod test_server_common { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { runtime .assert_server_request_capabilities(request.method()) diff --git a/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs b/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs index 5c184cf..9f2fd95 100644 --- a/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs +++ b/crates/rust-mcp-sdk/tests/test_protocol_compatibility.rs @@ -30,7 +30,7 @@ mod protocol_compatibility_on_server { ); handler - .handle_initialize_request(InitializeRequest::new(initialize_request), &runtime) + .handle_initialize_request(InitializeRequest::new(initialize_request), runtime) .await } diff --git a/crates/rust-mcp-transport/src/stdio.rs b/crates/rust-mcp-transport/src/stdio.rs index 582af5d..0b67d64 100644 --- a/crates/rust-mcp-transport/src/stdio.rs +++ b/crates/rust-mcp-transport/src/stdio.rs @@ -210,13 +210,12 @@ where .take() .ok_or_else(|| TransportError::FromString("Unable to retrieve stderr.".into()))?; - let pending_requests_clone1 = self.pending_requests.clone(); - let pending_requests_clone2 = self.pending_requests.clone(); + let pending_requests_clone = self.pending_requests.clone(); tokio::spawn(async move { let _ = process.wait().await; // clean up pending requests to cancel waiting tasks - let mut pending_requests = pending_requests_clone1.lock().await; + let mut pending_requests = pending_requests_clone.lock().await; pending_requests.clear(); }); @@ -224,7 +223,7 @@ where Box::pin(stdout), Mutex::new(Box::pin(stdin)), IoStream::Readable(Box::pin(stderr)), - pending_requests_clone2, + self.pending_requests.clone(), self.options.timeout, cancellation_token, ); diff --git a/doc/getting-started-mcp-server.md b/doc/getting-started-mcp-server.md index 358b1b4..6fac258 100644 --- a/doc/getting-started-mcp-server.md +++ b/doc/getting-started-mcp-server.md @@ -160,7 +160,7 @@ impl ServerHandler for MyServerHandler { async fn handle_list_tools_request( &self, _request: ListToolsRequest, - _runtime: &dyn McpServer, + _runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { meta: None, @@ -173,7 +173,7 @@ impl ServerHandler for MyServerHandler { async fn handle_call_tool_request( &self, request: CallToolRequest, - _runtime: &dyn McpServer, + _runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = diff --git a/examples/hello-world-mcp-server-core/src/handler.rs b/examples/hello-world-mcp-server-core/src/handler.rs index f0bdefe..acf55ea 100644 --- a/examples/hello-world-mcp-server-core/src/handler.rs +++ b/examples/hello-world-mcp-server-core/src/handler.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_trait::async_trait; use rust_mcp_sdk::schema::{ @@ -22,7 +24,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_request( &self, request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { let method_name = &request.method().to_owned(); match request { @@ -90,7 +92,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_notification( &self, notification: NotificationFromClient, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -99,7 +101,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_error( &self, error: &RpcError, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } diff --git a/examples/hello-world-mcp-server/src/handler.rs b/examples/hello-world-mcp-server/src/handler.rs index d9741a0..47925a0 100644 --- a/examples/hello-world-mcp-server/src/handler.rs +++ b/examples/hello-world-mcp-server/src/handler.rs @@ -4,6 +4,7 @@ use rust_mcp_sdk::schema::{ ListToolsResult, RpcError, }; use rust_mcp_sdk::{mcp_server::ServerHandler, McpServer}; +use std::sync::Arc; use crate::tools::GreetingTools; @@ -20,7 +21,7 @@ impl ServerHandler for MyServerHandler { async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { meta: None, @@ -33,7 +34,7 @@ impl ServerHandler for MyServerHandler { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = diff --git a/examples/hello-world-mcp-server/src/main.rs b/examples/hello-world-mcp-server/src/main.rs index 00ca6a7..98ff6f0 100644 --- a/examples/hello-world-mcp-server/src/main.rs +++ b/examples/hello-world-mcp-server/src/main.rs @@ -1,6 +1,8 @@ mod handler; mod tools; +use std::sync::Arc; + use handler::MyServerHandler; use rust_mcp_sdk::schema::{ Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, @@ -40,7 +42,8 @@ async fn main() -> SdkResult<()> { let handler = MyServerHandler {}; // STEP 4: create a MCP server - let server: ServerRuntime = server_runtime::create_server(server_details, transport, handler); + let server: Arc = + server_runtime::create_server(server_details, transport, handler); // STEP 5: Start the server if let Err(start_error) = server.start().await { diff --git a/examples/hello-world-server-core-streamable-http/src/handler.rs b/examples/hello-world-server-core-streamable-http/src/handler.rs index 1c69e8c..7941075 100644 --- a/examples/hello-world-server-core-streamable-http/src/handler.rs +++ b/examples/hello-world-server-core-streamable-http/src/handler.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_trait::async_trait; use rust_mcp_sdk::schema::{ @@ -22,7 +24,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_request( &self, request: RequestFromClient, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { let method_name = &request.method().to_owned(); match request { @@ -95,7 +97,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_notification( &self, notification: NotificationFromClient, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } @@ -104,7 +106,7 @@ impl ServerHandlerCore for MyServerHandler { async fn handle_error( &self, error: &RpcError, - _: &dyn McpServer, + _: Arc, ) -> std::result::Result<(), RpcError> { Ok(()) } diff --git a/examples/hello-world-server-streamable-http/src/handler.rs b/examples/hello-world-server-streamable-http/src/handler.rs index b8ce355..c4732d2 100644 --- a/examples/hello-world-server-streamable-http/src/handler.rs +++ b/examples/hello-world-server-streamable-http/src/handler.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use async_trait::async_trait; use rust_mcp_sdk::schema::{ schema_utils::CallToolError, CallToolRequest, CallToolResult, ListToolsRequest, @@ -20,7 +22,7 @@ impl ServerHandler for MyServerHandler { async fn handle_list_tools_request( &self, request: ListToolsRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { Ok(ListToolsResult { meta: None, @@ -33,7 +35,7 @@ impl ServerHandler for MyServerHandler { async fn handle_call_tool_request( &self, request: CallToolRequest, - runtime: &dyn McpServer, + runtime: Arc, ) -> std::result::Result { // Attempt to convert request parameters into GreetingTools enum let tool_params: GreetingTools = @@ -45,6 +47,4 @@ impl ServerHandler for MyServerHandler { GreetingTools::SayGoodbyeTool(say_goodbye_tool) => say_goodbye_tool.call_tool(), } } - - async fn on_server_started(&self, runtime: &dyn McpServer) {} }