diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java
index 750828adb..e277e4749 100644
--- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java
+++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java
@@ -440,4 +440,4 @@ public WebFluxStreamableServerTransportProvider build() {
}
-}
+}
\ No newline at end of file
diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java
new file mode 100644
index 000000000..4d2dc62f4
--- /dev/null
+++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java
@@ -0,0 +1,803 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ */
+
+package io.modelcontextprotocol.server.transport;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.locks.ReentrantLock;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import io.modelcontextprotocol.server.DefaultMcpTransportContext;
+import io.modelcontextprotocol.server.McpTransportContext;
+import io.modelcontextprotocol.server.McpTransportContextExtractor;
+import io.modelcontextprotocol.spec.HttpHeaders;
+import io.modelcontextprotocol.spec.McpError;
+import io.modelcontextprotocol.spec.McpSchema;
+import io.modelcontextprotocol.spec.McpStreamableServerSession;
+import io.modelcontextprotocol.spec.McpStreamableServerTransport;
+import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
+import io.modelcontextprotocol.util.Assert;
+import jakarta.servlet.AsyncContext;
+import jakarta.servlet.ServletException;
+import jakarta.servlet.annotation.WebServlet;
+import jakarta.servlet.http.HttpServlet;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import reactor.core.publisher.Mono;
+
+/**
+ * Server-side implementation of the Model Context Protocol (MCP) streamable transport
+ * layer using HTTP with Server-Sent Events (SSE) through HttpServlet. This implementation
+ * provides a bridge between synchronous HttpServlet operations and reactive programming
+ * patterns to maintain compatibility with the reactive transport interface.
+ *
+ *
+ * This is the HttpServlet equivalent of
+ * {@link io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider}
+ * for the core MCP module, providing streamable HTTP transport functionality without
+ * Spring dependencies.
+ *
+ * @author Zachary German
+ * @author Christian Tzolov
+ * @author Dariusz Jędrzejczyk
+ * @see McpStreamableServerTransportProvider
+ * @see HttpServlet
+ */
+@WebServlet(asyncSupported = true)
+public class HttpServletStreamableServerTransportProvider extends HttpServlet
+ implements McpStreamableServerTransportProvider {
+
+ private static final Logger logger = LoggerFactory.getLogger(HttpServletStreamableServerTransportProvider.class);
+
+ /**
+ * Event type for JSON-RPC messages sent through the SSE connection.
+ */
+ public static final String MESSAGE_EVENT_TYPE = "message";
+
+ /**
+ * Event type for sending the message endpoint URI to clients.
+ */
+ public static final String ENDPOINT_EVENT_TYPE = "endpoint";
+
+ /**
+ * Header name for the response media types accepted by the requester.
+ */
+ private static final String ACCEPT = "Accept";
+
+ public static final String UTF_8 = "UTF-8";
+
+ public static final String APPLICATION_JSON = "application/json";
+
+ public static final String TEXT_EVENT_STREAM = "text/event-stream";
+
+ public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}";
+
+ /**
+ * The endpoint URI where clients should send their JSON-RPC messages. Defaults to
+ * "/mcp".
+ */
+ private final String mcpEndpoint;
+
+ /**
+ * Flag indicating whether DELETE requests are disallowed on the endpoint.
+ */
+ private final boolean disallowDelete;
+
+ private final ObjectMapper objectMapper;
+
+ private McpStreamableServerSession.Factory sessionFactory;
+
+ /**
+ * Map of active client sessions, keyed by mcp-session-id.
+ */
+ private final ConcurrentHashMap sessions = new ConcurrentHashMap<>();
+
+ private McpTransportContextExtractor contextExtractor;
+
+ /**
+ * Flag indicating if the transport is shutting down.
+ */
+ private volatile boolean isClosing = false;
+
+ /**
+ * Constructs a new HttpServletStreamableServerTransportProvider instance.
+ * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
+ * of messages.
+ * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC
+ * messages via HTTP. This endpoint will handle GET, POST, and DELETE requests.
+ * @param disallowDelete Whether to disallow DELETE requests on the endpoint.
+ * @param contextExtractor The extractor for transport context from the request.
+ * @throws IllegalArgumentException if any parameter is null
+ */
+ private HttpServletStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint,
+ boolean disallowDelete, McpTransportContextExtractor contextExtractor) {
+ Assert.notNull(objectMapper, "ObjectMapper must not be null");
+ Assert.notNull(mcpEndpoint, "MCP endpoint must not be null");
+ Assert.notNull(contextExtractor, "Context extractor must not be null");
+
+ this.objectMapper = objectMapper;
+ this.mcpEndpoint = mcpEndpoint;
+ this.disallowDelete = disallowDelete;
+ this.contextExtractor = contextExtractor;
+ }
+
+ @Override
+ public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) {
+ this.sessionFactory = sessionFactory;
+ }
+
+ /**
+ * Broadcasts a notification to all connected clients through their SSE connections.
+ * If any errors occur during sending to a particular client, they are logged but
+ * don't prevent sending to other clients.
+ * @param method The method name for the notification
+ * @param params The parameters for the notification
+ * @return A Mono that completes when the broadcast attempt is finished
+ */
+ @Override
+ public Mono notifyClients(String method, Object params) {
+ if (this.sessions.isEmpty()) {
+ logger.debug("No active sessions to broadcast message to");
+ return Mono.empty();
+ }
+
+ logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
+
+ return Mono.fromRunnable(() -> {
+ this.sessions.values().parallelStream().forEach(session -> {
+ try {
+ session.sendNotification(method, params).block();
+ }
+ catch (Exception e) {
+ logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage());
+ }
+ });
+ });
+ }
+
+ /**
+ * Initiates a graceful shutdown of the transport.
+ * @return A Mono that completes when all cleanup operations are finished
+ */
+ @Override
+ public Mono closeGracefully() {
+ return Mono.fromRunnable(() -> {
+ this.isClosing = true;
+ logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
+
+ this.sessions.values().parallelStream().forEach(session -> {
+ try {
+ session.closeGracefully().block();
+ }
+ catch (Exception e) {
+ logger.error("Failed to close session {}: {}", session.getId(), e.getMessage());
+ }
+ });
+
+ this.sessions.clear();
+ logger.debug("Graceful shutdown completed");
+ });
+ }
+
+ /**
+ * Handles GET requests to establish SSE connections and message replay.
+ * @param request The HTTP servlet request
+ * @param response The HTTP servlet response
+ * @throws ServletException If a servlet-specific error occurs
+ * @throws IOException If an I/O error occurs
+ */
+ @Override
+ protected void doGet(HttpServletRequest request, HttpServletResponse response)
+ throws ServletException, IOException {
+
+ String requestURI = request.getRequestURI();
+ if (!requestURI.endsWith(mcpEndpoint)) {
+ response.sendError(HttpServletResponse.SC_NOT_FOUND);
+ return;
+ }
+
+ if (this.isClosing) {
+ response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
+ return;
+ }
+
+ List badRequestErrors = new ArrayList<>();
+
+ String accept = request.getHeader(ACCEPT);
+ if (accept == null || !accept.contains(TEXT_EVENT_STREAM)) {
+ badRequestErrors.add("text/event-stream required in Accept header");
+ }
+
+ String sessionId = request.getHeader(HttpHeaders.MCP_SESSION_ID);
+
+ if (sessionId == null || sessionId.isBlank()) {
+ badRequestErrors.add("Session ID required in mcp-session-id header");
+ }
+
+ if (!badRequestErrors.isEmpty()) {
+ String combinedMessage = String.join("; ", badRequestErrors);
+ this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage));
+ return;
+ }
+
+ McpStreamableServerSession session = this.sessions.get(sessionId);
+
+ if (session == null) {
+ response.sendError(HttpServletResponse.SC_NOT_FOUND);
+ return;
+ }
+
+ logger.debug("Handling GET request for session: {}", sessionId);
+
+ McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
+
+ try {
+ response.setContentType(TEXT_EVENT_STREAM);
+ response.setCharacterEncoding(UTF_8);
+ response.setHeader("Cache-Control", "no-cache");
+ response.setHeader("Connection", "keep-alive");
+ response.setHeader("Access-Control-Allow-Origin", "*");
+
+ AsyncContext asyncContext = request.startAsync();
+ asyncContext.setTimeout(0);
+
+ HttpServletStreamableMcpSessionTransport sessionTransport = new HttpServletStreamableMcpSessionTransport(
+ sessionId, asyncContext, response.getWriter());
+
+ // Check if this is a replay request
+ if (request.getHeader(HttpHeaders.LAST_EVENT_ID) != null) {
+ String lastId = request.getHeader(HttpHeaders.LAST_EVENT_ID);
+
+ try {
+ session.replay(lastId)
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
+ .toIterable()
+ .forEach(message -> {
+ try {
+ sessionTransport.sendMessage(message)
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
+ .block();
+ }
+ catch (Exception e) {
+ logger.error("Failed to replay message: {}", e.getMessage());
+ asyncContext.complete();
+ }
+ });
+ }
+ catch (Exception e) {
+ logger.error("Failed to replay messages: {}", e.getMessage());
+ asyncContext.complete();
+ }
+ }
+ else {
+ // Establish new listening stream
+ McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session
+ .listeningStream(sessionTransport);
+
+ asyncContext.addListener(new jakarta.servlet.AsyncListener() {
+ @Override
+ public void onComplete(jakarta.servlet.AsyncEvent event) throws IOException {
+ logger.debug("SSE connection completed for session: {}", sessionId);
+ listeningStream.close();
+ }
+
+ @Override
+ public void onTimeout(jakarta.servlet.AsyncEvent event) throws IOException {
+ logger.debug("SSE connection timed out for session: {}", sessionId);
+ listeningStream.close();
+ }
+
+ @Override
+ public void onError(jakarta.servlet.AsyncEvent event) throws IOException {
+ logger.debug("SSE connection error for session: {}", sessionId);
+ listeningStream.close();
+ }
+
+ @Override
+ public void onStartAsync(jakarta.servlet.AsyncEvent event) throws IOException {
+ // No action needed
+ }
+ });
+ }
+ }
+ catch (Exception e) {
+ logger.error("Failed to handle GET request for session {}: {}", sessionId, e.getMessage());
+ response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
+ }
+ }
+
+ /**
+ * Handles POST requests for incoming JSON-RPC messages from clients.
+ * @param request The HTTP servlet request containing the JSON-RPC message
+ * @param response The HTTP servlet response
+ * @throws ServletException If a servlet-specific error occurs
+ * @throws IOException If an I/O error occurs
+ */
+ @Override
+ protected void doPost(HttpServletRequest request, HttpServletResponse response)
+ throws ServletException, IOException {
+
+ String requestURI = request.getRequestURI();
+ if (!requestURI.endsWith(mcpEndpoint)) {
+ response.sendError(HttpServletResponse.SC_NOT_FOUND);
+ return;
+ }
+
+ if (this.isClosing) {
+ response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
+ return;
+ }
+
+ List badRequestErrors = new ArrayList<>();
+
+ String accept = request.getHeader(ACCEPT);
+ if (accept == null || !accept.contains(TEXT_EVENT_STREAM)) {
+ badRequestErrors.add("text/event-stream required in Accept header");
+ }
+ if (accept == null || !accept.contains(APPLICATION_JSON)) {
+ badRequestErrors.add("application/json required in Accept header");
+ }
+
+ McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
+
+ try {
+ BufferedReader reader = request.getReader();
+ StringBuilder body = new StringBuilder();
+ String line;
+ while ((line = reader.readLine()) != null) {
+ body.append(line);
+ }
+
+ McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString());
+
+ // Handle initialization request
+ if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest
+ && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) {
+ if (!badRequestErrors.isEmpty()) {
+ String combinedMessage = String.join("; ", badRequestErrors);
+ this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage));
+ return;
+ }
+
+ McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(),
+ new TypeReference() {
+ });
+ McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory
+ .startSession(initializeRequest);
+ this.sessions.put(init.session().getId(), init.session());
+
+ try {
+ McpSchema.InitializeResult initResult = init.initResult().block();
+
+ response.setContentType(APPLICATION_JSON);
+ response.setCharacterEncoding(UTF_8);
+ response.setHeader(HttpHeaders.MCP_SESSION_ID, init.session().getId());
+ response.setStatus(HttpServletResponse.SC_OK);
+
+ String jsonResponse = objectMapper.writeValueAsString(new McpSchema.JSONRPCResponse(
+ McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initResult, null));
+
+ PrintWriter writer = response.getWriter();
+ writer.write(jsonResponse);
+ writer.flush();
+ return;
+ }
+ catch (Exception e) {
+ logger.error("Failed to initialize session: {}", e.getMessage());
+ this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
+ new McpError("Failed to initialize session: " + e.getMessage()));
+ return;
+ }
+ }
+
+ String sessionId = request.getHeader(HttpHeaders.MCP_SESSION_ID);
+
+ if (sessionId == null || sessionId.isBlank()) {
+ badRequestErrors.add("Session ID required in mcp-session-id header");
+ }
+
+ if (!badRequestErrors.isEmpty()) {
+ String combinedMessage = String.join("; ", badRequestErrors);
+ this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage));
+ return;
+ }
+
+ McpStreamableServerSession session = this.sessions.get(sessionId);
+
+ if (session == null) {
+ this.responseError(response, HttpServletResponse.SC_NOT_FOUND,
+ new McpError("Session not found: " + sessionId));
+ return;
+ }
+
+ if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) {
+ session.accept(jsonrpcResponse)
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
+ .block();
+ response.setStatus(HttpServletResponse.SC_ACCEPTED);
+ }
+ else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) {
+ session.accept(jsonrpcNotification)
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
+ .block();
+ response.setStatus(HttpServletResponse.SC_ACCEPTED);
+ }
+ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) {
+ // For streaming responses, we need to return SSE
+ response.setContentType(TEXT_EVENT_STREAM);
+ response.setCharacterEncoding(UTF_8);
+ response.setHeader("Cache-Control", "no-cache");
+ response.setHeader("Connection", "keep-alive");
+ response.setHeader("Access-Control-Allow-Origin", "*");
+
+ AsyncContext asyncContext = request.startAsync();
+ asyncContext.setTimeout(0);
+
+ HttpServletStreamableMcpSessionTransport sessionTransport = new HttpServletStreamableMcpSessionTransport(
+ sessionId, asyncContext, response.getWriter());
+
+ try {
+ session.responseStream(jsonrpcRequest, sessionTransport)
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
+ .block();
+ }
+ catch (Exception e) {
+ logger.error("Failed to handle request stream: {}", e.getMessage());
+ asyncContext.complete();
+ }
+ }
+ else {
+ this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
+ new McpError("Unknown message type"));
+ }
+ }
+ catch (IllegalArgumentException | IOException e) {
+ logger.error("Failed to deserialize message: {}", e.getMessage());
+ this.responseError(response, HttpServletResponse.SC_BAD_REQUEST,
+ new McpError("Invalid message format: " + e.getMessage()));
+ }
+ catch (Exception e) {
+ logger.error("Error handling message: {}", e.getMessage());
+ try {
+ this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
+ new McpError("Error processing message: " + e.getMessage()));
+ }
+ catch (IOException ex) {
+ logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage());
+ response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error processing message");
+ }
+ }
+ }
+
+ /**
+ * Handles DELETE requests for session deletion.
+ * @param request The HTTP servlet request
+ * @param response The HTTP servlet response
+ * @throws ServletException If a servlet-specific error occurs
+ * @throws IOException If an I/O error occurs
+ */
+ @Override
+ protected void doDelete(HttpServletRequest request, HttpServletResponse response)
+ throws ServletException, IOException {
+
+ String requestURI = request.getRequestURI();
+ if (!requestURI.endsWith(mcpEndpoint)) {
+ response.sendError(HttpServletResponse.SC_NOT_FOUND);
+ return;
+ }
+
+ if (this.isClosing) {
+ response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down");
+ return;
+ }
+
+ if (this.disallowDelete) {
+ response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
+ return;
+ }
+
+ McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
+
+ if (request.getHeader(HttpHeaders.MCP_SESSION_ID) == null) {
+ this.responseError(response, HttpServletResponse.SC_BAD_REQUEST,
+ new McpError("Session ID required in mcp-session-id header"));
+ return;
+ }
+
+ String sessionId = request.getHeader(HttpHeaders.MCP_SESSION_ID);
+ McpStreamableServerSession session = this.sessions.get(sessionId);
+
+ if (session == null) {
+ response.sendError(HttpServletResponse.SC_NOT_FOUND);
+ return;
+ }
+
+ try {
+ session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block();
+ this.sessions.remove(sessionId);
+ response.setStatus(HttpServletResponse.SC_OK);
+ }
+ catch (Exception e) {
+ logger.error("Failed to delete session {}: {}", sessionId, e.getMessage());
+ try {
+ this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
+ new McpError(e.getMessage()));
+ }
+ catch (IOException ex) {
+ logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage());
+ response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error deleting session");
+ }
+ }
+ }
+
+ public void responseError(HttpServletResponse response, int httpCode, McpError mcpError) throws IOException {
+ response.setContentType(APPLICATION_JSON);
+ response.setCharacterEncoding(UTF_8);
+ response.setStatus(httpCode);
+ String jsonError = objectMapper.writeValueAsString(mcpError);
+ PrintWriter writer = response.getWriter();
+ writer.write(jsonError);
+ writer.flush();
+ return;
+ }
+
+ /**
+ * Sends an SSE event to a client with a specific ID.
+ * @param writer The writer to send the event through
+ * @param eventType The type of event (message or endpoint)
+ * @param data The event data
+ * @param id The event ID
+ * @throws IOException If an error occurs while writing the event
+ */
+ private void sendEvent(PrintWriter writer, String eventType, String data, String id) throws IOException {
+ if (id != null) {
+ writer.write("id: " + id + "\n");
+ }
+ writer.write("event: " + eventType + "\n");
+ writer.write("data: " + data + "\n\n");
+ writer.flush();
+
+ if (writer.checkError()) {
+ throw new IOException("Client disconnected");
+ }
+ }
+
+ /**
+ * Cleans up resources when the servlet is being destroyed.
+ *
+ * This method ensures a graceful shutdown by closing all client connections before
+ * calling the parent's destroy method.
+ */
+ @Override
+ public void destroy() {
+ closeGracefully().block();
+ super.destroy();
+ }
+
+ /**
+ * Implementation of McpStreamableServerTransport for HttpServlet SSE sessions. This
+ * class handles the transport-level communication for a specific client session.
+ *
+ *
+ * This class is thread-safe and uses a ReentrantLock to synchronize access to the
+ * underlying PrintWriter to prevent race conditions when multiple threads attempt to
+ * send messages concurrently.
+ */
+
+ private class HttpServletStreamableMcpSessionTransport implements McpStreamableServerTransport {
+
+ private final String sessionId;
+
+ private final AsyncContext asyncContext;
+
+ private final PrintWriter writer;
+
+ private volatile boolean closed = false;
+
+ private final ReentrantLock lock = new ReentrantLock();
+
+ /**
+ * Creates a new session transport with the specified ID and SSE writer.
+ * @param sessionId The unique identifier for this session
+ * @param asyncContext The async context for the session
+ * @param writer The writer for sending server events to the client
+ */
+ HttpServletStreamableMcpSessionTransport(String sessionId, AsyncContext asyncContext, PrintWriter writer) {
+ this.sessionId = sessionId;
+ this.asyncContext = asyncContext;
+ this.writer = writer;
+ logger.debug("Streamable session transport {} initialized with SSE writer", sessionId);
+ }
+
+ /**
+ * Sends a JSON-RPC message to the client through the SSE connection.
+ * @param message The JSON-RPC message to send
+ * @return A Mono that completes when the message has been sent
+ */
+ @Override
+ public Mono sendMessage(McpSchema.JSONRPCMessage message) {
+ return sendMessage(message, null);
+ }
+
+ /**
+ * Sends a JSON-RPC message to the client through the SSE connection with a
+ * specific message ID.
+ * @param message The JSON-RPC message to send
+ * @param messageId The message ID for SSE event identification
+ * @return A Mono that completes when the message has been sent
+ */
+ @Override
+ public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) {
+ return Mono.fromRunnable(() -> {
+ if (this.closed) {
+ logger.debug("Attempted to send message to closed session: {}", this.sessionId);
+ return;
+ }
+
+ lock.lock();
+ try {
+ if (this.closed) {
+ logger.debug("Session {} was closed during message send attempt", this.sessionId);
+ return;
+ }
+
+ String jsonText = objectMapper.writeValueAsString(message);
+ HttpServletStreamableServerTransportProvider.this.sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText,
+ messageId != null ? messageId : this.sessionId);
+ logger.debug("Message sent to session {} with ID {}", this.sessionId, messageId);
+ }
+ catch (Exception e) {
+ logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage());
+ HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId);
+ this.asyncContext.complete();
+ }
+ finally {
+ lock.unlock();
+ }
+ });
+ }
+
+ /**
+ * Converts data from one type to another using the configured ObjectMapper.
+ * @param data The source data object to convert
+ * @param typeRef The target type reference
+ * @return The converted object of type T
+ * @param The target type
+ */
+ @Override
+ public T unmarshalFrom(Object data, TypeReference typeRef) {
+ return objectMapper.convertValue(data, typeRef);
+ }
+
+ /**
+ * Initiates a graceful shutdown of the transport.
+ * @return A Mono that completes when the shutdown is complete
+ */
+ @Override
+ public Mono closeGracefully() {
+ return Mono.fromRunnable(() -> {
+ HttpServletStreamableMcpSessionTransport.this.close();
+ });
+ }
+
+ /**
+ * Closes the transport immediately.
+ */
+ @Override
+ public void close() {
+ lock.lock();
+ try {
+ if (this.closed) {
+ logger.debug("Session transport {} already closed", this.sessionId);
+ return;
+ }
+
+ this.closed = true;
+
+ // HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId);
+ this.asyncContext.complete();
+ logger.debug("Successfully completed async context for session {}", sessionId);
+ }
+ catch (Exception e) {
+ logger.warn("Failed to complete async context for session {}: {}", sessionId, e.getMessage());
+ }
+ finally {
+ lock.unlock();
+ }
+ }
+
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /**
+ * Builder for creating instances of
+ * {@link HttpServletStreamableServerTransportProvider}.
+ */
+ public static class Builder {
+
+ private ObjectMapper objectMapper;
+
+ private String mcpEndpoint = "/mcp";
+
+ private boolean disallowDelete = false;
+
+ private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context;
+
+ /**
+ * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
+ * messages.
+ * @param objectMapper The ObjectMapper instance. Must not be null.
+ * @return this builder instance
+ * @throws IllegalArgumentException if objectMapper is null
+ */
+ public Builder objectMapper(ObjectMapper objectMapper) {
+ Assert.notNull(objectMapper, "ObjectMapper must not be null");
+ this.objectMapper = objectMapper;
+ return this;
+ }
+
+ /**
+ * Sets the endpoint URI where clients should send their JSON-RPC messages.
+ * @param mcpEndpoint The MCP endpoint URI. Must not be null.
+ * @return this builder instance
+ * @throws IllegalArgumentException if mcpEndpoint is null
+ */
+ public Builder mcpEndpoint(String mcpEndpoint) {
+ Assert.notNull(mcpEndpoint, "MCP endpoint must not be null");
+ this.mcpEndpoint = mcpEndpoint;
+ return this;
+ }
+
+ /**
+ * Sets whether to disallow DELETE requests on the endpoint.
+ * @param disallowDelete true to disallow DELETE requests, false otherwise
+ * @return this builder instance
+ */
+ public Builder disallowDelete(boolean disallowDelete) {
+ this.disallowDelete = disallowDelete;
+ return this;
+ }
+
+ /**
+ * Sets the context extractor for extracting transport context from the request.
+ * @param contextExtractor The context extractor to use. Must not be null.
+ * @return this builder instance
+ * @throws IllegalArgumentException if contextExtractor is null
+ */
+ public Builder contextExtractor(McpTransportContextExtractor contextExtractor) {
+ Assert.notNull(contextExtractor, "Context extractor must not be null");
+ this.contextExtractor = contextExtractor;
+ return this;
+ }
+
+ /**
+ * Builds a new instance of {@link HttpServletStreamableServerTransportProvider}
+ * with the configured settings.
+ * @return A new HttpServletStreamableServerTransportProvider instance
+ * @throws IllegalStateException if required parameters are not set
+ */
+ public HttpServletStreamableServerTransportProvider build() {
+ Assert.notNull(this.objectMapper, "ObjectMapper must be set");
+ Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set");
+
+ return new HttpServletStreamableServerTransportProvider(this.objectMapper, this.mcpEndpoint,
+ this.disallowDelete, this.contextExtractor);
+ }
+
+ }
+
+}
diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java
new file mode 100644
index 000000000..687ff6ae9
--- /dev/null
+++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java
@@ -0,0 +1,1271 @@
+/*
+ * Copyright 2024 - 2024 the original author or authors.
+ */
+package io.modelcontextprotocol.server;
+
+import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson;
+import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.awaitility.Awaitility.await;
+import static org.mockito.Mockito.mock;
+
+import java.net.URI;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.time.Duration;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+
+import io.modelcontextprotocol.client.McpClient;
+import io.modelcontextprotocol.spec.McpError;
+import io.modelcontextprotocol.spec.McpSchema;
+import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
+import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
+import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
+import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
+import io.modelcontextprotocol.spec.McpSchema.ElicitRequest;
+import io.modelcontextprotocol.spec.McpSchema.ElicitResult;
+import io.modelcontextprotocol.spec.McpSchema.InitializeResult;
+import io.modelcontextprotocol.spec.McpSchema.ModelPreferences;
+import io.modelcontextprotocol.spec.McpSchema.Role;
+import io.modelcontextprotocol.spec.McpSchema.Root;
+import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
+import io.modelcontextprotocol.spec.McpSchema.Tool;
+import net.javacrumbs.jsonunit.core.Option;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+
+public abstract class AbstractMcpClientServerIntegrationTests {
+
+ protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>();
+
+ abstract protected void prepareClients(int port, String mcpEndpoint);
+
+ abstract protected McpServer.AsyncSpecification> prepareAsyncServerBuilder();
+
+ abstract protected McpServer.SyncSpecification> prepareSyncServerBuilder();
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void simple(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0")
+ .requestTimeout(Duration.ofSeconds(1000))
+ .build();
+
+ try (
+ // Create client without sampling capabilities
+ var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0"))
+ .requestTimeout(Duration.ofSeconds(1000))
+ .build()) {
+
+ assertThat(client.initialize()).isNotNull();
+
+ }
+ server.closeGracefully();
+ }
+
+ // ---------------------------------------
+ // Sampling Tests
+ // ---------------------------------------
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testCreateMessageWithoutSamplingCapabilities(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
+ .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build())
+ .callHandler((exchange, request) -> {
+ exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block();
+ return Mono.just(mock(CallToolResult.class));
+ })
+ .build();
+
+ var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build();
+
+ try (
+ // Create client without sampling capabilities
+ var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0"))
+ .build()) {
+
+ assertThat(client.initialize()).isNotNull();
+
+ try {
+ client.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
+ }
+ catch (McpError e) {
+ assertThat(e).isInstanceOf(McpError.class)
+ .hasMessage("Client must be configured with sampling capabilities");
+ }
+ }
+ server.closeGracefully();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testCreateMessageSuccess(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ Function samplingHandler = request -> {
+ assertThat(request.messages()).hasSize(1);
+ assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class);
+
+ return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName",
+ CreateMessageResult.StopReason.STOP_SEQUENCE);
+ };
+
+ CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
+ null);
+
+ McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
+ .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build())
+ .callHandler((exchange, request) -> {
+
+ var createMessageRequest = McpSchema.CreateMessageRequest.builder()
+ .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
+ new McpSchema.TextContent("Test message"))))
+ .modelPreferences(ModelPreferences.builder()
+ .hints(List.of())
+ .costPriority(1.0)
+ .speedPriority(1.0)
+ .intelligencePriority(1.0)
+ .build())
+ .build();
+
+ StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> {
+ assertThat(result).isNotNull();
+ assertThat(result.role()).isEqualTo(Role.USER);
+ assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
+ assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message");
+ assertThat(result.model()).isEqualTo("MockModelName");
+ assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
+ }).verifyComplete();
+
+ return Mono.just(callResponse);
+ })
+ .build();
+
+ //@formatter:off
+ var mcpServer = prepareAsyncServerBuilder()
+ .serverInfo("test-server", "1.0.0")
+ .tools(tool)
+ .build();
+
+ try (
+ var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
+ .capabilities(ClientCapabilities.builder().sampling().build())
+ .sampling(samplingHandler)
+ .build()) {//@formatter:on
+
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
+
+ assertThat(response).isNotNull().isEqualTo(callResponse);
+ }
+ mcpServer.close();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException {
+
+ // Client
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ Function samplingHandler = request -> {
+ assertThat(request.messages()).hasSize(1);
+ assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class);
+ try {
+ TimeUnit.SECONDS.sleep(2);
+ }
+ catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName",
+ CreateMessageResult.StopReason.STOP_SEQUENCE);
+ };
+
+ var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
+ .capabilities(ClientCapabilities.builder().sampling().build())
+ .sampling(samplingHandler)
+ .build();
+
+ // Server
+
+ CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
+ null);
+
+ McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
+ .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build())
+ .callHandler((exchange, request) -> {
+
+ var createMessageRequest = McpSchema.CreateMessageRequest.builder()
+ .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
+ new McpSchema.TextContent("Test message"))))
+ .modelPreferences(ModelPreferences.builder()
+ .hints(List.of())
+ .costPriority(1.0)
+ .speedPriority(1.0)
+ .intelligencePriority(1.0)
+ .build())
+ .build();
+
+ StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> {
+ assertThat(result).isNotNull();
+ assertThat(result.role()).isEqualTo(Role.USER);
+ assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
+ assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message");
+ assertThat(result.model()).isEqualTo("MockModelName");
+ assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
+ }).verifyComplete();
+
+ return Mono.just(callResponse);
+ })
+ .build();
+
+ var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0")
+ .requestTimeout(Duration.ofSeconds(4))
+ .tools(tool)
+ .build();
+
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
+
+ assertThat(response).isNotNull();
+ assertThat(response).isEqualTo(callResponse);
+
+ mcpClient.close();
+ mcpServer.close();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ Function samplingHandler = request -> {
+ assertThat(request.messages()).hasSize(1);
+ assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class);
+ try {
+ TimeUnit.SECONDS.sleep(2);
+ }
+ catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName",
+ CreateMessageResult.StopReason.STOP_SEQUENCE);
+ };
+
+ var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
+ .capabilities(ClientCapabilities.builder().sampling().build())
+ .sampling(samplingHandler)
+ .build();
+
+ CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
+ null);
+
+ McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
+ .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build())
+ .callHandler((exchange, request) -> {
+
+ var createMessageRequest = McpSchema.CreateMessageRequest.builder()
+ .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
+ new McpSchema.TextContent("Test message"))))
+ .modelPreferences(ModelPreferences.builder()
+ .hints(List.of())
+ .costPriority(1.0)
+ .speedPriority(1.0)
+ .intelligencePriority(1.0)
+ .build())
+ .build();
+
+ StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> {
+ assertThat(result).isNotNull();
+ assertThat(result.role()).isEqualTo(Role.USER);
+ assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
+ assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message");
+ assertThat(result.model()).isEqualTo("MockModelName");
+ assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
+ }).verifyComplete();
+
+ return Mono.just(callResponse);
+ })
+ .build();
+
+ var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0")
+ .requestTimeout(Duration.ofSeconds(1))
+ .tools(tool)
+ .build();
+
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ assertThatExceptionOfType(McpError.class).isThrownBy(() -> {
+ mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
+ }).withMessageContaining("Timeout");
+
+ mcpClient.close();
+ mcpServer.close();
+ }
+
+ // ---------------------------------------
+ // Elicitation Tests
+ // ---------------------------------------
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testCreateElicitationWithoutElicitationCapabilities(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
+ .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build())
+ .callHandler((exchange, request) -> {
+
+ exchange.createElicitation(mock(McpSchema.ElicitRequest.class)).block();
+
+ return Mono.just(mock(CallToolResult.class));
+ })
+ .build();
+
+ var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build();
+
+ try (
+ // Create client without elicitation capabilities
+ var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) {
+
+ assertThat(client.initialize()).isNotNull();
+
+ try {
+ client.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
+ }
+ catch (McpError e) {
+ assertThat(e).isInstanceOf(McpError.class)
+ .hasMessage("Client must be configured with elicitation capabilities");
+ }
+ }
+ server.closeGracefully().block();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testCreateElicitationSuccess(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ Function elicitationHandler = request -> {
+ assertThat(request.message()).isNotEmpty();
+ assertThat(request.requestedSchema()).isNotNull();
+
+ return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT,
+ Map.of("message", request.message()));
+ };
+
+ CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
+ null);
+
+ McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
+ .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build())
+ .callHandler((exchange, request) -> {
+
+ var elicitationRequest = McpSchema.ElicitRequest.builder()
+ .message("Test message")
+ .requestedSchema(
+ Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string"))))
+ .build();
+
+ StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> {
+ assertThat(result).isNotNull();
+ assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT);
+ assertThat(result.content().get("message")).isEqualTo("Test message");
+ }).verifyComplete();
+
+ return Mono.just(callResponse);
+ })
+ .build();
+
+ var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build();
+
+ try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
+ .capabilities(ClientCapabilities.builder().elicitation().build())
+ .elicitation(elicitationHandler)
+ .build()) {
+
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
+
+ assertThat(response).isNotNull();
+ assertThat(response).isEqualTo(callResponse);
+ }
+ mcpServer.closeGracefully().block();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testCreateElicitationWithRequestTimeoutSuccess(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ Function elicitationHandler = request -> {
+ assertThat(request.message()).isNotEmpty();
+ assertThat(request.requestedSchema()).isNotNull();
+ try {
+ TimeUnit.SECONDS.sleep(2);
+ }
+ catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT,
+ Map.of("message", request.message()));
+ };
+
+ var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
+ .capabilities(ClientCapabilities.builder().elicitation().build())
+ .elicitation(elicitationHandler)
+ .build();
+
+ CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
+ null);
+
+ McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
+ .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build())
+ .callHandler((exchange, request) -> {
+
+ var elicitationRequest = McpSchema.ElicitRequest.builder()
+ .message("Test message")
+ .requestedSchema(
+ Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string"))))
+ .build();
+
+ StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> {
+ assertThat(result).isNotNull();
+ assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT);
+ assertThat(result.content().get("message")).isEqualTo("Test message");
+ }).verifyComplete();
+
+ return Mono.just(callResponse);
+ })
+ .build();
+
+ var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0")
+ .requestTimeout(Duration.ofSeconds(3))
+ .tools(tool)
+ .build();
+
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
+
+ assertThat(response).isNotNull();
+ assertThat(response).isEqualTo(callResponse);
+
+ mcpClient.closeGracefully();
+ mcpServer.closeGracefully().block();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testCreateElicitationWithRequestTimeoutFail(String clientType) {
+
+ var latch = new CountDownLatch(1);
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ Function elicitationHandler = request -> {
+ assertThat(request.message()).isNotEmpty();
+ assertThat(request.requestedSchema()).isNotNull();
+
+ try {
+ if (!latch.await(2, TimeUnit.SECONDS)) {
+ throw new RuntimeException("Timeout waiting for elicitation processing");
+ }
+ }
+ catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message()));
+ };
+
+ var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
+ .capabilities(ClientCapabilities.builder().elicitation().build())
+ .elicitation(elicitationHandler)
+ .build();
+
+ CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null);
+
+ AtomicReference resultRef = new AtomicReference<>();
+
+ McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
+ .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build())
+ .callHandler((exchange, request) -> {
+
+ var elicitationRequest = ElicitRequest.builder()
+ .message("Test message")
+ .requestedSchema(
+ Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string"))))
+ .build();
+
+ return exchange.createElicitation(elicitationRequest)
+ .doOnNext(resultRef::set)
+ .then(Mono.just(callResponse));
+ })
+ .build();
+
+ var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0")
+ .requestTimeout(Duration.ofSeconds(1)) // 1 second.
+ .tools(tool)
+ .build();
+
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ assertThatExceptionOfType(McpError.class).isThrownBy(() -> {
+ mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
+ }).withMessageContaining("within 1000ms");
+
+ ElicitResult elicitResult = resultRef.get();
+ assertThat(elicitResult).isNull();
+
+ mcpClient.closeGracefully();
+ mcpServer.closeGracefully().block();
+ }
+
+ // ---------------------------------------
+ // Roots Tests
+ // ---------------------------------------
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testRootsSuccess(String clientType) {
+ var clientBuilder = clientBuilders.get(clientType);
+
+ List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2"));
+
+ AtomicReference> rootsRef = new AtomicReference<>();
+
+ var mcpServer = prepareSyncServerBuilder()
+ .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate))
+ .build();
+
+ try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
+ .roots(roots)
+ .build()) {
+
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ assertThat(rootsRef.get()).isNull();
+
+ mcpClient.rootsListChangedNotification();
+
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef.get()).containsAll(roots);
+ });
+
+ // Remove a root
+ mcpClient.removeRoot(roots.get(0).uri());
+
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef.get()).containsAll(List.of(roots.get(1)));
+ });
+
+ // Add a new root
+ var root3 = new Root("uri3://", "root3");
+ mcpClient.addRoot(root3);
+
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3));
+ });
+ }
+
+ mcpServer.close();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testRootsWithoutCapability(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder()
+ .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build())
+ .callHandler((exchange, request) -> {
+
+ exchange.listRoots(); // try to list roots
+
+ return mock(CallToolResult.class);
+ })
+ .build();
+
+ var mcpServer = prepareSyncServerBuilder().rootsChangeHandler((exchange, rootsUpdate) -> {
+ }).tools(tool).build();
+
+ try (
+ // Create client without roots capability
+ // No roots capability
+ var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) {
+
+ assertThat(mcpClient.initialize()).isNotNull();
+
+ // Attempt to list roots should fail
+ try {
+ mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
+ }
+ catch (McpError e) {
+ assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported");
+ }
+ }
+
+ mcpServer.close();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testRootsNotificationWithEmptyRootsList(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ AtomicReference> rootsRef = new AtomicReference<>();
+
+ var mcpServer = prepareSyncServerBuilder()
+ .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate))
+ .build();
+
+ try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
+ .roots(List.of()) // Empty roots list
+ .build()) {
+
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ mcpClient.rootsListChangedNotification();
+
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef.get()).isEmpty();
+ });
+ }
+
+ mcpServer.close();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testRootsWithMultipleHandlers(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ List roots = List.of(new Root("uri1://", "root1"));
+
+ AtomicReference> rootsRef1 = new AtomicReference<>();
+ AtomicReference> rootsRef2 = new AtomicReference<>();
+
+ var mcpServer = prepareSyncServerBuilder()
+ .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate))
+ .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate))
+ .build();
+
+ try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
+ .roots(roots)
+ .build()) {
+
+ assertThat(mcpClient.initialize()).isNotNull();
+
+ mcpClient.rootsListChangedNotification();
+
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef1.get()).containsAll(roots);
+ assertThat(rootsRef2.get()).containsAll(roots);
+ });
+ }
+
+ mcpServer.close();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testRootsServerCloseWithActiveSubscription(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ List roots = List.of(new Root("uri1://", "root1"));
+
+ AtomicReference> rootsRef = new AtomicReference<>();
+
+ var mcpServer = prepareSyncServerBuilder()
+ .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate))
+ .build();
+
+ try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
+ .roots(roots)
+ .build()) {
+
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ mcpClient.rootsListChangedNotification();
+
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef.get()).containsAll(roots);
+ });
+ }
+
+ mcpServer.close();
+ }
+
+ // ---------------------------------------
+ // Tools Tests
+ // ---------------------------------------
+
+ String emptyJsonSchema = """
+ {
+ "$schema": "http://json-schema.org/draft-07/schema#",
+ "type": "object",
+ "properties": {}
+ }
+ """;
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testToolCallSuccess(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null);
+ McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder()
+ .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build())
+ .callHandler((exchange, request) -> {
+
+ try {
+ HttpResponse response = HttpClient.newHttpClient()
+ .send(HttpRequest.newBuilder()
+ .uri(URI.create(
+ "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md"))
+ .GET()
+ .build(), HttpResponse.BodyHandlers.ofString());
+ String responseBody = response.body();
+ assertThat(responseBody).isNotBlank();
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ }
+
+ return callResponse;
+ })
+ .build();
+
+ var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build())
+ .tools(tool1)
+ .build();
+
+ try (var mcpClient = clientBuilder.build()) {
+
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ assertThat(mcpClient.listTools().tools()).contains(tool1.tool());
+
+ CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
+
+ assertThat(response).isNotNull().isEqualTo(callResponse);
+ }
+
+ mcpServer.close();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ McpSyncServer mcpServer = prepareSyncServerBuilder()
+ .capabilities(ServerCapabilities.builder().tools(true).build())
+ .tools(McpServerFeatures.SyncToolSpecification.builder()
+ .tool(Tool.builder()
+ .name("tool1")
+ .description("tool1 description")
+ .inputSchema(emptyJsonSchema)
+ .build())
+ .callHandler((exchange, request) -> {
+ // We trigger a timeout on blocking read, raising an exception
+ Mono.never().block(Duration.ofSeconds(1));
+ return null;
+ })
+ .build())
+ .build();
+
+ try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) {
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ // We expect the tool call to fail immediately with the exception raised by
+ // the offending tool
+ // instead of getting back a timeout.
+ assertThatExceptionOfType(McpError.class)
+ .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())))
+ .withMessageContaining("Timeout on blocking read");
+ }
+
+ mcpServer.close();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testToolListChangeHandlingSuccess(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null);
+ McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder()
+ .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build())
+ .callHandler((exchange, request) -> {
+ // perform a blocking call to a remote service
+ try {
+ HttpResponse response = HttpClient.newHttpClient()
+ .send(HttpRequest.newBuilder()
+ .uri(URI.create(
+ "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md"))
+ .GET()
+ .build(), HttpResponse.BodyHandlers.ofString());
+ String responseBody = response.body();
+ assertThat(responseBody).isNotBlank();
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ }
+ return callResponse;
+ })
+ .build();
+
+ AtomicReference> rootsRef = new AtomicReference<>();
+
+ var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build())
+ .tools(tool1)
+ .build();
+
+ try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> {
+ // perform a blocking call to a remote service
+ try {
+ HttpResponse response = HttpClient.newHttpClient()
+ .send(HttpRequest.newBuilder()
+ .uri(URI.create(
+ "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md"))
+ .GET()
+ .build(), HttpResponse.BodyHandlers.ofString());
+ String responseBody = response.body();
+ assertThat(responseBody).isNotBlank();
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ }
+
+ rootsRef.set(toolsUpdate);
+ }).build()) {
+
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ assertThat(rootsRef.get()).isNull();
+
+ assertThat(mcpClient.listTools().tools()).contains(tool1.tool());
+
+ mcpServer.notifyToolsListChanged();
+
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef.get()).containsAll(List.of(tool1.tool()));
+ });
+
+ // Remove a tool
+ mcpServer.removeTool("tool1");
+
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef.get()).isEmpty();
+ });
+
+ // Add a new tool
+ McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder()
+ .tool(Tool.builder()
+ .name("tool2")
+ .description("tool2 description")
+ .inputSchema(emptyJsonSchema)
+ .build())
+ .callHandler((exchange, request) -> callResponse)
+ .build();
+
+ mcpServer.addTool(tool2);
+
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(rootsRef.get()).containsAll(List.of(tool2.tool()));
+ });
+ }
+
+ mcpServer.close();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testInitialize(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ var mcpServer = prepareSyncServerBuilder().build();
+
+ try (var mcpClient = clientBuilder.build()) {
+
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+ }
+
+ mcpServer.close();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testPingSuccess(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ // Create server with a tool that uses ping functionality
+ AtomicReference executionOrder = new AtomicReference<>("");
+
+ McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
+ .tool(Tool.builder()
+ .name("ping-async-test")
+ .description("Test ping async behavior")
+ .inputSchema(emptyJsonSchema)
+ .build())
+ .callHandler((exchange, request) -> {
+
+ executionOrder.set(executionOrder.get() + "1");
+
+ // Test async ping behavior
+ return exchange.ping().doOnNext(result -> {
+
+ assertThat(result).isNotNull();
+ // Ping should return an empty object or map
+ assertThat(result).isInstanceOf(Map.class);
+
+ executionOrder.set(executionOrder.get() + "2");
+ assertThat(result).isNotNull();
+ }).then(Mono.fromCallable(() -> {
+ executionOrder.set(executionOrder.get() + "3");
+ return new CallToolResult("Async ping test completed", false);
+ }));
+ })
+ .build();
+
+ var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0")
+ .capabilities(ServerCapabilities.builder().tools(true).build())
+ .tools(tool)
+ .build();
+
+ try (var mcpClient = clientBuilder.build()) {
+
+ // Initialize client
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ // Call the tool that tests ping async behavior
+ CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of()));
+ assertThat(result).isNotNull();
+ assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class);
+ assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed");
+
+ // Verify execution order
+ assertThat(executionOrder.get()).isEqualTo("123");
+ }
+
+ mcpServer.close();
+ }
+
+ // ---------------------------------------
+ // Tool Structured Output Schema Tests
+ // ---------------------------------------
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testStructuredOutputValidationSuccess(String clientType) {
+ var clientBuilder = clientBuilders.get(clientType);
+
+ // Create a tool with output schema
+ Map outputSchema = Map.of(
+ "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation",
+ Map.of("type", "string"), "timestamp", Map.of("type", "string")),
+ "required", List.of("result", "operation"));
+
+ Tool calculatorTool = Tool.builder()
+ .name("calculator")
+ .description("Performs mathematical calculations")
+ .outputSchema(outputSchema)
+ .build();
+
+ McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder()
+ .tool(calculatorTool)
+ .callHandler((exchange, request) -> {
+ String expression = (String) request.arguments().getOrDefault("expression", "2 + 3");
+ double result = evaluateExpression(expression);
+ return CallToolResult.builder()
+ .structuredContent(
+ Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z"))
+ .build();
+ })
+ .build();
+
+ var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0")
+ .capabilities(ServerCapabilities.builder().tools(true).build())
+ .tools(tool)
+ .build();
+
+ try (var mcpClient = clientBuilder.build()) {
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ // Verify tool is listed with output schema
+ var toolsList = mcpClient.listTools();
+ assertThat(toolsList.tools()).hasSize(1);
+ assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator");
+ // Note: outputSchema might be null in sync server, but validation still works
+
+ // Call tool with valid structured output
+ CallToolResult response = mcpClient
+ .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")));
+
+ assertThat(response).isNotNull();
+ assertThat(response.isError()).isFalse();
+
+ // In WebMVC, structured content is returned properly
+ if (response.structuredContent() != null) {
+ assertThat(response.structuredContent()).containsEntry("result", 5.0)
+ .containsEntry("operation", "2 + 3")
+ .containsEntry("timestamp", "2024-01-01T10:00:00Z");
+ }
+ else {
+ // Fallback to checking content if structured content is not available
+ assertThat(response.content()).isNotEmpty();
+ }
+
+ assertThat(response.structuredContent()).isNotNull();
+ assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER)
+ .when(Option.IGNORING_EXTRA_ARRAY_ITEMS)
+ .isObject()
+ .isEqualTo(json("""
+ {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}"""));
+ }
+
+ mcpServer.close();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testStructuredOutputValidationFailure(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ // Create a tool with output schema
+ Map outputSchema = Map.of("type", "object", "properties",
+ Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required",
+ List.of("result", "operation"));
+
+ Tool calculatorTool = Tool.builder()
+ .name("calculator")
+ .description("Performs mathematical calculations")
+ .outputSchema(outputSchema)
+ .build();
+
+ McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder()
+ .tool(calculatorTool)
+ .callHandler((exchange, request) -> {
+ // Return invalid structured output. Result should be number, missing
+ // operation
+ return CallToolResult.builder()
+ .addTextContent("Invalid calculation")
+ .structuredContent(Map.of("result", "not-a-number", "extra", "field"))
+ .build();
+ })
+ .build();
+
+ var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0")
+ .capabilities(ServerCapabilities.builder().tools(true).build())
+ .tools(tool)
+ .build();
+
+ try (var mcpClient = clientBuilder.build()) {
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ // Call tool with invalid structured output
+ CallToolResult response = mcpClient
+ .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")));
+
+ assertThat(response).isNotNull();
+ assertThat(response.isError()).isTrue();
+ assertThat(response.content()).hasSize(1);
+ assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class);
+
+ String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text();
+ assertThat(errorMessage).contains("Validation failed");
+ }
+
+ mcpServer.close();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testStructuredOutputMissingStructuredContent(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ // Create a tool with output schema
+ Map outputSchema = Map.of("type", "object", "properties",
+ Map.of("result", Map.of("type", "number")), "required", List.of("result"));
+
+ Tool calculatorTool = Tool.builder()
+ .name("calculator")
+ .description("Performs mathematical calculations")
+ .outputSchema(outputSchema)
+ .build();
+
+ McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder()
+ .tool(calculatorTool)
+ .callHandler((exchange, request) -> {
+ // Return result without structured content but tool has output schema
+ return CallToolResult.builder().addTextContent("Calculation completed").build();
+ })
+ .build();
+
+ var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0")
+ .capabilities(ServerCapabilities.builder().tools(true).build())
+ .tools(tool)
+ .build();
+
+ try (var mcpClient = clientBuilder.build()) {
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ // Call tool that should return structured content but doesn't
+ CallToolResult response = mcpClient
+ .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")));
+
+ assertThat(response).isNotNull();
+ assertThat(response.isError()).isTrue();
+ assertThat(response.content()).hasSize(1);
+ assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class);
+
+ String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text();
+ assertThat(errorMessage).isEqualTo(
+ "Response missing structured content which is expected when calling tool with non-empty outputSchema");
+ }
+
+ mcpServer.close();
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient" })
+ void testStructuredOutputRuntimeToolAddition(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ // Start server without tools
+ var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0")
+ .capabilities(ServerCapabilities.builder().tools(true).build())
+ .build();
+
+ try (var mcpClient = clientBuilder.build()) {
+ InitializeResult initResult = mcpClient.initialize();
+ assertThat(initResult).isNotNull();
+
+ // Initially no tools
+ assertThat(mcpClient.listTools().tools()).isEmpty();
+
+ // Add tool with output schema at runtime
+ Map outputSchema = Map.of("type", "object", "properties",
+ Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required",
+ List.of("message", "count"));
+
+ Tool dynamicTool = Tool.builder()
+ .name("dynamic-tool")
+ .description("Dynamically added tool")
+ .outputSchema(outputSchema)
+ .build();
+
+ McpServerFeatures.SyncToolSpecification toolSpec = McpServerFeatures.SyncToolSpecification.builder()
+ .tool(dynamicTool)
+ .callHandler((exchange, request) -> {
+ int count = (Integer) request.arguments().getOrDefault("count", 1);
+ return CallToolResult.builder()
+ .addTextContent("Dynamic tool executed " + count + " times")
+ .structuredContent(Map.of("message", "Dynamic execution", "count", count))
+ .build();
+ })
+ .build();
+
+ // Add tool to server
+ mcpServer.addTool(toolSpec);
+
+ // Wait for tool list change notification
+ await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
+ assertThat(mcpClient.listTools().tools()).hasSize(1);
+ });
+
+ // Verify tool was added with output schema
+ var toolsList = mcpClient.listTools();
+ assertThat(toolsList.tools()).hasSize(1);
+ assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool");
+ // Note: outputSchema might be null in sync server, but validation still works
+
+ // Call dynamically added tool
+ CallToolResult response = mcpClient
+ .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3)));
+
+ assertThat(response).isNotNull();
+ assertThat(response.isError()).isFalse();
+
+ assertThat(response.content()).hasSize(1);
+ assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class);
+ assertThat(((McpSchema.TextContent) response.content().get(0)).text())
+ .isEqualTo("Dynamic tool executed 3 times");
+
+ assertThat(response.structuredContent()).isNotNull();
+ assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER)
+ .when(Option.IGNORING_EXTRA_ARRAY_ITEMS)
+ .isObject()
+ .isEqualTo(json("""
+ {"count":3,"message":"Dynamic execution"}"""));
+ }
+
+ mcpServer.close();
+ }
+
+ private double evaluateExpression(String expression) {
+ // Simple expression evaluator for testing
+ return switch (expression) {
+ case "2 + 3" -> 5.0;
+ case "10 * 2" -> 20.0;
+ case "7 + 8" -> 15.0;
+ case "5 + 3" -> 8.0;
+ default -> 0.0;
+ };
+ }
+
+}
diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java
new file mode 100644
index 000000000..327ec1b21
--- /dev/null
+++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ */
+
+package io.modelcontextprotocol.server;
+
+import org.junit.jupiter.api.Timeout;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider;
+import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
+
+/**
+ * Tests for {@link McpAsyncServer} using
+ * {@link HttpServletStreamableServerTransportProvider}.
+ *
+ * @author Christian Tzolov
+ */
+@Timeout(15)
+class HttpServletStreamableAsyncServerTests extends AbstractMcpAsyncServerTests {
+
+ protected McpStreamableServerTransportProvider createMcpTransportProvider() {
+ return HttpServletStreamableServerTransportProvider.builder()
+ .objectMapper(new ObjectMapper())
+ .mcpEndpoint("/mcp/message")
+ .build();
+ }
+
+ @Override
+ protected McpServer.AsyncSpecification> prepareAsyncServerBuilder() {
+ return McpServer.async(createMcpTransportProvider());
+ }
+
+}
diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java
new file mode 100644
index 000000000..3377f98a6
--- /dev/null
+++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2024 - 2024 the original author or authors.
+ */
+package io.modelcontextprotocol.server;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.time.Duration;
+
+import org.apache.catalina.LifecycleException;
+import org.apache.catalina.LifecycleState;
+import org.apache.catalina.startup.Tomcat;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import io.modelcontextprotocol.client.McpClient;
+import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
+import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
+import io.modelcontextprotocol.server.McpServer.SyncSpecification;
+import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider;
+import io.modelcontextprotocol.server.transport.TomcatTestUtil;
+
+class HttpServletStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests {
+
+ private static final int PORT = TomcatTestUtil.findAvailablePort();
+
+ private static final String MESSAGE_ENDPOINT = "/mcp/message";
+
+ private HttpServletStreamableServerTransportProvider mcpServerTransportProvider;
+
+ private Tomcat tomcat;
+
+ @BeforeEach
+ public void before() {
+ // Create and configure the transport provider
+ mcpServerTransportProvider = HttpServletStreamableServerTransportProvider.builder()
+ .objectMapper(new ObjectMapper())
+ .mcpEndpoint(MESSAGE_ENDPOINT)
+ .build();
+
+ tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider);
+ try {
+ tomcat.start();
+ assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED);
+ }
+ catch (Exception e) {
+ throw new RuntimeException("Failed to start Tomcat", e);
+ }
+
+ clientBuilders
+ .put("httpclient",
+ McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT)
+ .endpoint(MESSAGE_ENDPOINT)
+ .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10)));
+ }
+
+ @Override
+ protected AsyncSpecification> prepareAsyncServerBuilder() {
+ return McpServer.async(this.mcpServerTransportProvider);
+ }
+
+ @Override
+ protected SyncSpecification> prepareSyncServerBuilder() {
+ return McpServer.sync(this.mcpServerTransportProvider);
+ }
+
+ @AfterEach
+ public void after() {
+ if (mcpServerTransportProvider != null) {
+ mcpServerTransportProvider.closeGracefully().block();
+ }
+ if (tomcat != null) {
+ try {
+ tomcat.stop();
+ tomcat.destroy();
+ }
+ catch (LifecycleException e) {
+ throw new RuntimeException("Failed to stop Tomcat", e);
+ }
+ }
+ }
+
+ @Override
+ protected void prepareClients(int port, String mcpEndpoint) {
+ }
+
+}
diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java
new file mode 100644
index 000000000..66fa2b2ac
--- /dev/null
+++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ */
+
+package io.modelcontextprotocol.server;
+
+import org.junit.jupiter.api.Timeout;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider;
+import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
+
+/**
+ * Tests for {@link McpSyncServer} using
+ * {@link HttpServletStreamableServerTransportProvider}.
+ *
+ * @author Christian Tzolov
+ */
+@Timeout(15)
+class HttpServletStreamableSyncServerTests extends AbstractMcpSyncServerTests {
+
+ protected McpStreamableServerTransportProvider createMcpTransportProvider() {
+ return HttpServletStreamableServerTransportProvider.builder()
+ .objectMapper(new ObjectMapper())
+ .mcpEndpoint("/mcp/message")
+ .build();
+ }
+
+ @Override
+ protected McpServer.SyncSpecification> prepareSyncServerBuilder() {
+ return McpServer.sync(createMcpTransportProvider());
+ }
+
+}