diff --git a/crates/rmcp/src/lib.rs b/crates/rmcp/src/lib.rs index 3476390a..6cc678ac 100644 --- a/crates/rmcp/src/lib.rs +++ b/crates/rmcp/src/lib.rs @@ -102,7 +102,7 @@ //! start a MCP server in Python and then list the tools and call `git status` //! as follows: //! -//! ```rust +//! ```rust,ignore //! use anyhow::Result; //! use rmcp::{model::CallToolRequestParam, service::ServiceExt, transport::{TokioChildProcess, ConfigureCommandExt}}; //! use tokio::process::Command; diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index 81286fed..52c798ed 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -39,7 +39,7 @@ //! //! ## Examples //! -//! ```rust +//! ```rust,ignore //! # use rmcp::{ //! # ServiceExt, serve_client, serve_server, //! # }; diff --git a/crates/rmcp/src/transport/common/client_side_sse.rs b/crates/rmcp/src/transport/common/client_side_sse.rs index 2a5a4465..4e01994f 100644 --- a/crates/rmcp/src/transport/common/client_side_sse.rs +++ b/crates/rmcp/src/transport/common/client_side_sse.rs @@ -98,10 +98,29 @@ impl SseStreamReconnect for NeverReconnect { } } +/// Abstraction for SSE reconnection logic. Implementors can hook into +/// [`handle_control_event`](Self::handle_control_event) to consume control +/// frames (e.g. `event: endpoint`) that arrive when a server restarts an SSE +/// stream. The default implementation is a no-op, keeping existing behaviour +/// intact. pub(crate) trait SseStreamReconnect { type Error: std::error::Error; type Future: Future> + Send; fn retry_connection(&mut self, last_event_id: Option<&str>) -> Self::Future; + fn handle_control_event(&mut self, _event: &Sse) -> Result<(), Self::Error> { + Ok(()) + } + fn handle_stream_error( + &mut self, + error: &(dyn std::error::Error + 'static), + last_event_id: Option<&str>, + ) { + if let Some(id) = last_event_id { + tracing::warn!(%id, "sse stream error: {error}"); + } else { + tracing::warn!("sse stream error: {error}"); + } + } } pin_project_lite::pin_project! { @@ -189,14 +208,31 @@ where *this.server_retry_interval = Some(Duration::from_millis(new_server_retry)); } - if let Some(event_id) = sse.id { - *this.last_event_id = Some(event_id); + if let Some(ref event_id) = sse.id { + *this.last_event_id = Some(event_id.clone()); + } + // Only treat blank/`message` events as JSON-RPC payloads. + // Other control frames (endpoint, ping, etc.) are passed to + // the reconnection handler. + let is_message_event = + matches!(sse.event.as_deref(), None | Some("") | Some("message")); + if !is_message_event { + match this.connector.handle_control_event(&sse) { + Ok(()) => return self.poll_next(cx), + Err(e) => { + this.state.set(SseAutoReconnectStreamState::Terminated); + return Poll::Ready(Some(Err(e))); + } + } } if let Some(data) = sse.data { match serde_json::from_str::(&data) { Err(e) => { - // not sure should this be a hard error - tracing::warn!("failed to deserialize server message: {e}"); + // Downgrade to debug to avoid noisy logs when servers emit + // non-JSON payloads as message frames. Include last_event_id + // to aid troubleshooting while keeping default behaviour. + let last_id = this.last_event_id.as_deref().unwrap_or(""); + tracing::debug!(last_event_id=%last_id, "failed to deserialize server message: {e}"); return self.poll_next(cx); } Ok(message) => { @@ -208,7 +244,8 @@ where } } Some(Err(e)) => { - tracing::warn!("sse stream error: {e}"); + this.connector + .handle_stream_error(&e, this.last_event_id.as_deref()); let retrying = this .connector .retry_connection(this.last_event_id.as_deref()); diff --git a/crates/rmcp/src/transport/sse_client.rs b/crates/rmcp/src/transport/sse_client.rs index 7a0705c2..7b61e3d8 100644 --- a/crates/rmcp/src/transport/sse_client.rs +++ b/crates/rmcp/src/transport/sse_client.rs @@ -1,9 +1,12 @@ -//! reference: https://html.spec.whatwg.org/multipage/server-sent-events.html -use std::{pin::Pin, sync::Arc}; +//! Reference: +use std::{ + pin::Pin, + sync::{Arc, RwLock}, +}; use futures::{StreamExt, future::BoxFuture}; use http::Uri; -use sse_stream::Error as SseError; +use sse_stream::{Error as SseError, Sse}; use thiserror::Error; use super::{ @@ -54,9 +57,13 @@ pub trait SseClient: Clone + Send + Sync + 'static { ) -> impl Future>> + Send + '_; } +/// Helper that refreshes the POST endpoint whenever the server emits +/// control frames during SSE reconnect; used together with +/// [`SseAutoReconnectStream`]. struct SseClientReconnect { pub client: C, pub uri: Uri, + pub message_endpoint: Arc>, } impl SseStreamReconnect for SseClientReconnect { @@ -68,6 +75,37 @@ impl SseStreamReconnect for SseClientReconnect { let last_event_id = last_event_id.map(|s| s.to_owned()); Box::pin(async move { client.get_stream(uri, last_event_id, None).await }) } + + fn handle_control_event(&mut self, event: &Sse) -> Result<(), Self::Error> { + if event.event.as_deref() != Some("endpoint") { + return Ok(()); + } + let Some(data) = event.data.as_ref() else { + return Ok(()); + }; + // Servers typically resend the message POST endpoint (often with a new + // sessionId) when a stream reconnects. Reuse `message_endpoint` helper + // to resolve it and update the shared URI. + let new_endpoint = message_endpoint(self.uri.clone(), data.clone()) + .map_err(SseTransportError::InvalidUri)?; + *self + .message_endpoint + .write() + .expect("message endpoint lock poisoned") = new_endpoint; + Ok(()) + } + + fn handle_stream_error( + &mut self, + error: &(dyn std::error::Error + 'static), + last_event_id: Option<&str>, + ) { + tracing::warn!( + uri = %self.uri, + last_event_id = last_event_id.unwrap_or(""), + "sse stream error: {error}" + ); + } } type ServerMessageStream = Pin>>>; @@ -81,7 +119,7 @@ type ServerMessageStream = Pin = Pin = Pin { client: C, config: SseClientConfig, - message_endpoint: Uri, + /// Current POST endpoint; refreshed when the server sends new endpoint + /// control frames. + message_endpoint: Arc>, stream: Option>, } @@ -168,8 +208,16 @@ impl Transport for SseClientTransport { item: crate::service::TxJsonRpcMessage, ) -> impl Future> + Send + 'static { let client = self.client.clone(); - let uri = self.message_endpoint.clone(); - async move { client.post_message(uri, item, None).await } + let message_endpoint = self.message_endpoint.clone(); + async move { + let uri = { + let guard = message_endpoint + .read() + .expect("message endpoint lock poisoned"); + guard.clone() + }; + client.post_message(uri, item, None).await + } } async fn close(&mut self) -> Result<(), Self::Error> { self.stream.take(); @@ -194,7 +242,7 @@ impl SseClientTransport { let sse_endpoint = config.sse_endpoint.as_ref().parse::()?; let mut sse_stream = client.get_stream(sse_endpoint.clone(), None, None).await?; - let message_endpoint = if let Some(endpoint) = config.use_message_endpoint.clone() { + let initial_message_endpoint = if let Some(endpoint) = config.use_message_endpoint.clone() { let ep = endpoint.parse::()?; let mut sse_endpoint_parts = sse_endpoint.clone().into_parts(); sse_endpoint_parts.path_and_query = ep.into_parts().path_and_query; @@ -214,12 +262,14 @@ impl SseClientTransport { break message_endpoint(sse_endpoint.clone(), ep)?; } }; + let message_endpoint = Arc::new(RwLock::new(initial_message_endpoint)); let stream = Box::pin(SseAutoReconnectStream::new( sse_stream, SseClientReconnect { client: client.clone(), uri: sse_endpoint.clone(), + message_endpoint: message_endpoint.clone(), }, config.retry_policy.clone(), )); @@ -274,7 +324,7 @@ pub struct SseClientConfig { /// and the server send the message endpoint event as `message?session_id=123`, /// then the message endpoint will be `http://example.com/message`. /// - /// This follow the rules of JavaScript's [`new URL(url, base)`](https://developer.mozilla.org/zh-CN/docs/Web/API/URL/URL) + /// This follows the rules of JavaScript's [`new URL(url, base)`](https://developer.mozilla.org/en-US/docs/Web/API/URL/URL) pub sse_endpoint: Arc, pub retry_policy: Arc, /// if this is settled, the client will use this endpoint to send message and skip get the endpoint event @@ -293,8 +343,40 @@ impl Default for SseClientConfig { #[cfg(test)] mod tests { + use futures::StreamExt; + use serde_json::{Value, json}; + use super::*; + #[derive(Clone)] + struct DummyClient; + + #[derive(Debug, thiserror::Error)] + #[error("dummy error")] + struct DummyError; + + impl SseClient for DummyClient { + type Error = DummyError; + + async fn post_message( + &self, + _uri: Uri, + _message: ClientJsonRpcMessage, + _auth_token: Option, + ) -> Result<(), SseTransportError> { + Ok(()) + } + + async fn get_stream( + &self, + _uri: Uri, + _last_event_id: Option, + _auth_token: Option, + ) -> Result> { + unreachable!("get_stream should not be called in this test") + } + } + #[test] fn test_message_endpoint() { let base_url = "https://localhost/sse".parse::().unwrap(); @@ -319,4 +401,58 @@ mod tests { .unwrap(); assert_eq!(result.to_string(), "http://example.com/xxx?sessionId=x"); } + + #[test] + fn handle_endpoint_control_event_updates_uri() { + let initial_endpoint = "https://example.com/message?sessionId=old" + .parse::() + .unwrap(); + let shared_endpoint = Arc::new(RwLock::new(initial_endpoint)); + let mut reconnect = SseClientReconnect { + client: DummyClient, + uri: "https://example.com/sse".parse::().unwrap(), + message_endpoint: shared_endpoint.clone(), + }; + + let control_event = Sse::default() + .event("endpoint") + .data("/message?sessionId=new"); + + reconnect.handle_control_event(&control_event).unwrap(); + + let guard = shared_endpoint.read().expect("lock poisoned"); + assert_eq!( + guard.to_string(), + "https://example.com/message?sessionId=new" + ); + } + + #[tokio::test] + async fn control_event_frames_are_skipped() { + let payload = json!({ + "jsonrpc": "2.0", + "id": 1, + "result": {"ok": true} + }) + .to_string(); + + let events = vec![ + Ok(Sse::default() + .event("endpoint") + .data("/message?sessionId=reconnect")), + Ok(Sse::default().event("message").data(payload.clone())), + ]; + + let sse_src: BoxedSseResponse = futures::stream::iter(events).boxed(); + let reconn_stream = SseAutoReconnectStream::never_reconnect(sse_src, DummyError); + futures::pin_mut!(reconn_stream); + + let message = reconn_stream.next().await.expect("stream item").unwrap(); + let actual: Value = serde_json::to_value(message).expect("serialize actual message"); + // We only need to assert that a valid JSON-RPC response came through after + // skipping control frames. The exact `result` shape depends on the SDK's + // typed result enums and is not asserted here. + assert_eq!(actual.get("jsonrpc"), Some(&Value::String("2.0".into()))); + assert_eq!(actual.get("id"), Some(&Value::Number(1u64.into()))); + } } diff --git a/target_ci_mirror/CACHEDIR.TAG b/target_ci_mirror/CACHEDIR.TAG new file mode 100644 index 00000000..20d7c319 --- /dev/null +++ b/target_ci_mirror/CACHEDIR.TAG @@ -0,0 +1,3 @@ +Signature: 8a477f597d28d172789f06886806bc55 +# This file is a cache directory tag created by cargo. +# For information about cache directory tags see https://bford.info/cachedir/