Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/rmcp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
//! start a MCP server in Python and then list the tools and call `git status`
//! as follows:
//!
//! ```rust
//! ```rust,ignore
//! use anyhow::Result;
//! use rmcp::{model::CallToolRequestParam, service::ServiceExt, transport::{TokioChildProcess, ConfigureCommandExt}};
//! use tokio::process::Command;
Expand Down
2 changes: 1 addition & 1 deletion crates/rmcp/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
//!
//! ## Examples
//!
//! ```rust
//! ```rust,ignore
//! # use rmcp::{
//! # ServiceExt, serve_client, serve_server,
//! # };
Expand Down
47 changes: 42 additions & 5 deletions crates/rmcp/src/transport/common/client_side_sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,29 @@ impl<E: std::error::Error + Send> SseStreamReconnect for NeverReconnect<E> {
}
}

/// Abstraction for SSE reconnection logic. Implementors can hook into
/// [`handle_control_event`](Self::handle_control_event) to consume control
/// frames (e.g. `event: endpoint`) that arrive when a server restarts an SSE
/// stream. The default implementation is a no-op, keeping existing behaviour
/// intact.
pub(crate) trait SseStreamReconnect {
type Error: std::error::Error;
type Future: Future<Output = Result<BoxedSseResponse, Self::Error>> + Send;
fn retry_connection(&mut self, last_event_id: Option<&str>) -> Self::Future;
fn handle_control_event(&mut self, _event: &Sse) -> Result<(), Self::Error> {
Ok(())
}
fn handle_stream_error(
&mut self,
error: &(dyn std::error::Error + 'static),
last_event_id: Option<&str>,
) {
if let Some(id) = last_event_id {
tracing::warn!(%id, "sse stream error: {error}");
} else {
tracing::warn!("sse stream error: {error}");
}
}
}

pin_project_lite::pin_project! {
Expand Down Expand Up @@ -189,14 +208,31 @@ where
*this.server_retry_interval =
Some(Duration::from_millis(new_server_retry));
}
if let Some(event_id) = sse.id {
*this.last_event_id = Some(event_id);
if let Some(ref event_id) = sse.id {
*this.last_event_id = Some(event_id.clone());
}
// Only treat blank/`message` events as JSON-RPC payloads.
// Other control frames (endpoint, ping, etc.) are passed to
// the reconnection handler.
let is_message_event =
matches!(sse.event.as_deref(), None | Some("") | Some("message"));
if !is_message_event {
match this.connector.handle_control_event(&sse) {
Ok(()) => return self.poll_next(cx),
Err(e) => {
this.state.set(SseAutoReconnectStreamState::Terminated);
return Poll::Ready(Some(Err(e)));
}
}
}
if let Some(data) = sse.data {
match serde_json::from_str::<ServerJsonRpcMessage>(&data) {
Err(e) => {
// not sure should this be a hard error
tracing::warn!("failed to deserialize server message: {e}");
// Downgrade to debug to avoid noisy logs when servers emit
// non-JSON payloads as message frames. Include last_event_id
// to aid troubleshooting while keeping default behaviour.
let last_id = this.last_event_id.as_deref().unwrap_or("");
tracing::debug!(last_event_id=%last_id, "failed to deserialize server message: {e}");
return self.poll_next(cx);
}
Ok(message) => {
Expand All @@ -208,7 +244,8 @@ where
}
}
Some(Err(e)) => {
tracing::warn!("sse stream error: {e}");
this.connector
.handle_stream_error(&e, this.last_event_id.as_deref());
let retrying = this
.connector
.retry_connection(this.last_event_id.as_deref());
Expand Down
156 changes: 146 additions & 10 deletions crates/rmcp/src/transport/sse_client.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
//! reference: https://html.spec.whatwg.org/multipage/server-sent-events.html
use std::{pin::Pin, sync::Arc};
//! Reference: <https://html.spec.whatwg.org/multipage/server-sent-events.html>
use std::{
pin::Pin,
sync::{Arc, RwLock},
};

use futures::{StreamExt, future::BoxFuture};
use http::Uri;
use sse_stream::Error as SseError;
use sse_stream::{Error as SseError, Sse};
use thiserror::Error;

use super::{
Expand Down Expand Up @@ -54,9 +57,13 @@ pub trait SseClient: Clone + Send + Sync + 'static {
) -> impl Future<Output = Result<BoxedSseResponse, SseTransportError<Self::Error>>> + Send + '_;
}

/// Helper that refreshes the POST endpoint whenever the server emits
/// control frames during SSE reconnect; used together with
/// [`SseAutoReconnectStream`].
struct SseClientReconnect<C> {
pub client: C,
pub uri: Uri,
pub message_endpoint: Arc<RwLock<Uri>>,
}

impl<C: SseClient> SseStreamReconnect for SseClientReconnect<C> {
Expand All @@ -68,6 +75,37 @@ impl<C: SseClient> SseStreamReconnect for SseClientReconnect<C> {
let last_event_id = last_event_id.map(|s| s.to_owned());
Box::pin(async move { client.get_stream(uri, last_event_id, None).await })
}

fn handle_control_event(&mut self, event: &Sse) -> Result<(), Self::Error> {
if event.event.as_deref() != Some("endpoint") {
return Ok(());
}
let Some(data) = event.data.as_ref() else {
return Ok(());
};
// Servers typically resend the message POST endpoint (often with a new
// sessionId) when a stream reconnects. Reuse `message_endpoint` helper
// to resolve it and update the shared URI.
let new_endpoint = message_endpoint(self.uri.clone(), data.clone())
.map_err(SseTransportError::InvalidUri)?;
*self
.message_endpoint
.write()
.expect("message endpoint lock poisoned") = new_endpoint;
Ok(())
}

fn handle_stream_error(
&mut self,
error: &(dyn std::error::Error + 'static),
last_event_id: Option<&str>,
) {
tracing::warn!(
uri = %self.uri,
last_event_id = last_event_id.unwrap_or(""),
"sse stream error: {error}"
);
}
}
type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<C>>>>;

Expand All @@ -81,7 +119,7 @@ type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<
///
/// ## Using reqwest
///
/// ```rust
/// ```rust,ignore
/// use rmcp::transport::SseClientTransport;
///
/// // Enable the reqwest feature in Cargo.toml:
Expand All @@ -95,7 +133,7 @@ type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<
///
/// ## Using a custom HTTP client
///
/// ```rust
/// ```rust,ignore
/// use rmcp::transport::sse_client::{SseClient, SseClientTransport, SseClientConfig};
/// use std::sync::Arc;
/// use futures::stream::BoxStream;
Expand Down Expand Up @@ -154,7 +192,9 @@ type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<
pub struct SseClientTransport<C: SseClient> {
client: C,
config: SseClientConfig,
message_endpoint: Uri,
/// Current POST endpoint; refreshed when the server sends new endpoint
/// control frames.
message_endpoint: Arc<RwLock<Uri>>,
stream: Option<ServerMessageStream<C>>,
}

Expand All @@ -168,8 +208,16 @@ impl<C: SseClient> Transport<RoleClient> for SseClientTransport<C> {
item: crate::service::TxJsonRpcMessage<RoleClient>,
) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
let client = self.client.clone();
let uri = self.message_endpoint.clone();
async move { client.post_message(uri, item, None).await }
let message_endpoint = self.message_endpoint.clone();
async move {
let uri = {
let guard = message_endpoint
.read()
.expect("message endpoint lock poisoned");
guard.clone()
};
client.post_message(uri, item, None).await
}
}
async fn close(&mut self) -> Result<(), Self::Error> {
self.stream.take();
Expand All @@ -194,7 +242,7 @@ impl<C: SseClient> SseClientTransport<C> {
let sse_endpoint = config.sse_endpoint.as_ref().parse::<http::Uri>()?;

let mut sse_stream = client.get_stream(sse_endpoint.clone(), None, None).await?;
let message_endpoint = if let Some(endpoint) = config.use_message_endpoint.clone() {
let initial_message_endpoint = if let Some(endpoint) = config.use_message_endpoint.clone() {
let ep = endpoint.parse::<http::Uri>()?;
let mut sse_endpoint_parts = sse_endpoint.clone().into_parts();
sse_endpoint_parts.path_and_query = ep.into_parts().path_and_query;
Expand All @@ -214,12 +262,14 @@ impl<C: SseClient> SseClientTransport<C> {
break message_endpoint(sse_endpoint.clone(), ep)?;
}
};
let message_endpoint = Arc::new(RwLock::new(initial_message_endpoint));

let stream = Box::pin(SseAutoReconnectStream::new(
sse_stream,
SseClientReconnect {
client: client.clone(),
uri: sse_endpoint.clone(),
message_endpoint: message_endpoint.clone(),
},
config.retry_policy.clone(),
));
Expand Down Expand Up @@ -274,7 +324,7 @@ pub struct SseClientConfig {
/// and the server send the message endpoint event as `message?session_id=123`,
/// then the message endpoint will be `http://example.com/message`.
///
/// This follow the rules of JavaScript's [`new URL(url, base)`](https://developer.mozilla.org/zh-CN/docs/Web/API/URL/URL)
/// This follows the rules of JavaScript's [`new URL(url, base)`](https://developer.mozilla.org/en-US/docs/Web/API/URL/URL)
pub sse_endpoint: Arc<str>,
pub retry_policy: Arc<dyn SseRetryPolicy>,
/// if this is settled, the client will use this endpoint to send message and skip get the endpoint event
Expand All @@ -293,8 +343,40 @@ impl Default for SseClientConfig {

#[cfg(test)]
mod tests {
use futures::StreamExt;
use serde_json::{Value, json};

use super::*;

#[derive(Clone)]
struct DummyClient;

#[derive(Debug, thiserror::Error)]
#[error("dummy error")]
struct DummyError;

impl SseClient for DummyClient {
type Error = DummyError;

async fn post_message(
&self,
_uri: Uri,
_message: ClientJsonRpcMessage,
_auth_token: Option<String>,
) -> Result<(), SseTransportError<Self::Error>> {
Ok(())
}

async fn get_stream(
&self,
_uri: Uri,
_last_event_id: Option<String>,
_auth_token: Option<String>,
) -> Result<BoxedSseResponse, SseTransportError<Self::Error>> {
unreachable!("get_stream should not be called in this test")
}
}

#[test]
fn test_message_endpoint() {
let base_url = "https://localhost/sse".parse::<http::Uri>().unwrap();
Expand All @@ -319,4 +401,58 @@ mod tests {
.unwrap();
assert_eq!(result.to_string(), "http://example.com/xxx?sessionId=x");
}

#[test]
fn handle_endpoint_control_event_updates_uri() {
let initial_endpoint = "https://example.com/message?sessionId=old"
.parse::<Uri>()
.unwrap();
let shared_endpoint = Arc::new(RwLock::new(initial_endpoint));
let mut reconnect = SseClientReconnect {
client: DummyClient,
uri: "https://example.com/sse".parse::<Uri>().unwrap(),
message_endpoint: shared_endpoint.clone(),
};

let control_event = Sse::default()
.event("endpoint")
.data("/message?sessionId=new");

reconnect.handle_control_event(&control_event).unwrap();

let guard = shared_endpoint.read().expect("lock poisoned");
assert_eq!(
guard.to_string(),
"https://example.com/message?sessionId=new"
);
}

#[tokio::test]
async fn control_event_frames_are_skipped() {
let payload = json!({
"jsonrpc": "2.0",
"id": 1,
"result": {"ok": true}
})
.to_string();

let events = vec![
Ok(Sse::default()
.event("endpoint")
.data("/message?sessionId=reconnect")),
Ok(Sse::default().event("message").data(payload.clone())),
];

let sse_src: BoxedSseResponse = futures::stream::iter(events).boxed();
let reconn_stream = SseAutoReconnectStream::never_reconnect(sse_src, DummyError);
futures::pin_mut!(reconn_stream);

let message = reconn_stream.next().await.expect("stream item").unwrap();
let actual: Value = serde_json::to_value(message).expect("serialize actual message");
// We only need to assert that a valid JSON-RPC response came through after
// skipping control frames. The exact `result` shape depends on the SDK's
// typed result enums and is not asserted here.
assert_eq!(actual.get("jsonrpc"), Some(&Value::String("2.0".into())));
assert_eq!(actual.get("id"), Some(&Value::Number(1u64.into())));
}
}
3 changes: 3 additions & 0 deletions target_ci_mirror/CACHEDIR.TAG
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Signature: 8a477f597d28d172789f06886806bc55
# This file is a cache directory tag created by cargo.
# For information about cache directory tags see https://bford.info/cachedir/
Loading