diff --git a/mcp-bom/pom.xml b/mcp-bom/pom.xml index 7214dacda..83d8bc510 100644 --- a/mcp-bom/pom.xml +++ b/mcp-bom/pom.xml @@ -7,7 +7,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.12.0-SNAPSHOT mcp-bom diff --git a/mcp-spring/mcp-spring-webflux/pom.xml b/mcp-spring/mcp-spring-webflux/pom.xml index fdec82377..300d518e7 100644 --- a/mcp-spring/mcp-spring-webflux/pom.xml +++ b/mcp-spring/mcp-spring-webflux/pom.xml @@ -6,13 +6,13 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.12.0-SNAPSHOT ../../pom.xml mcp-spring-webflux jar - WebFlux implementation of the Java MCP SSE transport - + WebFlux transports + WebFlux implementation for the SSE and Streamable Http Client and Server transports https://github.com/modelcontextprotocol/java-sdk @@ -25,13 +25,13 @@ io.modelcontextprotocol.sdk mcp - 0.11.0-SNAPSHOT + 0.12.0-SNAPSHOT io.modelcontextprotocol.sdk mcp-test - 0.11.0-SNAPSHOT + 0.12.0-SNAPSHOT test diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index 53b59cb30..6d8e82f51 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -1,3 +1,7 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + package io.modelcontextprotocol.client.transport; import java.io.IOException; @@ -23,13 +27,16 @@ import io.modelcontextprotocol.spec.DefaultMcpTransportSession; import io.modelcontextprotocol.spec.DefaultMcpTransportStream; +import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpTransportSession; import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.spec.McpTransportStream; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -65,6 +72,8 @@ public class WebClientStreamableHttpTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(WebClientStreamableHttpTransport.class); + private static final String MCP_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_03_26; + private static final String DEFAULT_ENDPOINT = "/mcp"; /** @@ -102,6 +111,11 @@ private WebClientStreamableHttpTransport(ObjectMapper objectMapper, WebClient.Bu this.activeSession.set(createTransportSession()); } + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26); + } + /** * Create a stateful builder for creating {@link WebClientStreamableHttpTransport} * instances. @@ -127,12 +141,17 @@ public Mono connect(Function, Mono> onClose = sessionId -> sessionId == null ? Mono.empty() - : webClient.delete().uri(this.endpoint).headers(httpHeaders -> { - httpHeaders.add("mcp-session-id", sessionId); - }).retrieve().toBodilessEntity().onErrorComplete(e -> { - logger.warn("Got error when closing transport", e); - return true; - }).then(); + : webClient.delete() + .uri(this.endpoint) + .header(HttpHeaders.MCP_SESSION_ID, sessionId) + .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) + .retrieve() + .toBodilessEntity() + .onErrorComplete(e -> { + logger.warn("Got error when closing transport", e); + return true; + }) + .then(); return new DefaultMcpTransportSession(onClose); } @@ -185,10 +204,11 @@ private Mono reconnect(McpTransportStream stream) { Disposable connection = webClient.get() .uri(this.endpoint) .accept(MediaType.TEXT_EVENT_STREAM) + .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) .headers(httpHeaders -> { - transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); + transportSession.sessionId().ifPresent(id -> httpHeaders.add(HttpHeaders.MCP_SESSION_ID, id)); if (stream != null) { - stream.lastId().ifPresent(id -> httpHeaders.add("last-event-id", id)); + stream.lastId().ifPresent(id -> httpHeaders.add(HttpHeaders.LAST_EVENT_ID, id)); } }) .exchangeToFlux(response -> { @@ -244,14 +264,15 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { Disposable connection = webClient.post() .uri(this.endpoint) - .accept(MediaType.TEXT_EVENT_STREAM, MediaType.APPLICATION_JSON) + .accept(MediaType.APPLICATION_JSON, MediaType.TEXT_EVENT_STREAM) + .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) .headers(httpHeaders -> { - transportSession.sessionId().ifPresent(id -> httpHeaders.add("mcp-session-id", id)); + transportSession.sessionId().ifPresent(id -> httpHeaders.add(HttpHeaders.MCP_SESSION_ID, id)); }) .bodyValue(message) .exchangeToFlux(response -> { if (transportSession - .markInitialized(response.headers().asHttpHeaders().getFirst("mcp-session-id"))) { + .markInitialized(response.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID))) { // Once we have a session, we try to open an async stream for // the server to send notifications and requests out-of-band. reconnect(null).contextWrite(sink.contextView()).subscribe(); @@ -287,7 +308,7 @@ else if (mediaType.isCompatibleWith(MediaType.APPLICATION_JSON)) { logger.trace("Received response to POST for session {}", sessionRepresentation); // communicate to caller the message was delivered sink.success(); - return responseFlux(response); + return directResponseFlux(message, response); } else { logger.warn("Unknown media type {} returned for POST in session {}", contentType, @@ -340,7 +361,8 @@ private Flux extractError(ClientResponse response, Str McpSchema.JSONRPCResponse jsonRpcResponse = objectMapper.readValue(body, McpSchema.JSONRPCResponse.class); jsonRpcError = jsonRpcResponse.error(); - toPropagate = new McpError(jsonRpcError); + toPropagate = jsonRpcError != null ? new McpError(jsonRpcError) + : new McpError("Can't parse the jsonResponse " + jsonRpcResponse); } catch (IOException ex) { toPropagate = new RuntimeException("Sending request failed", e); @@ -384,14 +406,22 @@ private static String sessionIdOrPlaceholder(McpTransportSession transportSes return transportSession.sessionId().orElse("[missing_session_id]"); } - private Flux responseFlux(ClientResponse response) { + private Flux directResponseFlux(McpSchema.JSONRPCMessage sentMessage, + ClientResponse response) { return response.bodyToMono(String.class).>handle((responseMessage, s) -> { try { - McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, - responseMessage); - s.next(List.of(jsonRpcResponse)); + if (sentMessage instanceof McpSchema.JSONRPCNotification && Utils.hasText(responseMessage)) { + logger.warn("Notification: {} received non-compliant response: {}", sentMessage, responseMessage); + s.complete(); + } + else { + McpSchema.JSONRPCMessage jsonRpcResponse = McpSchema.deserializeJsonRpcMessage(objectMapper, + responseMessage); + s.next(List.of(jsonRpcResponse)); + } } catch (IOException e) { + // TODO: this should be a McpTransportError s.error(e); } }).flatMapIterable(Function.identity()); @@ -423,7 +453,8 @@ private Tuple2, Iterable> parse(Serve } } else { - throw new McpError("Received unrecognized SSE event type: " + event.event()); + logger.debug("Received SSE event with type: {}", event); + return Tuples.of(Optional.empty(), List.of()); } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java index 128cda4c3..75caebef0 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java @@ -1,18 +1,23 @@ /* * Copyright 2024 - 2024 the original author or authors. */ + package io.modelcontextprotocol.client.transport; import java.io.IOException; +import java.util.List; import java.util.function.BiConsumer; import java.util.function.Function; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -62,6 +67,8 @@ public class WebFluxSseClientTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(WebFluxSseClientTransport.class); + private static final String MCP_PROTOCOL_VERSION = ProtocolVersions.MCP_2024_11_05; + /** * Event type for JSON-RPC messages received through the SSE connection. The server * sends messages with this event type to transmit JSON-RPC protocol data. @@ -166,6 +173,11 @@ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, ObjectMappe this.sseEndpoint = sseEndpoint; } + @Override + public List protocolVersions() { + return List.of(MCP_PROTOCOL_VERSION); + } + /** * Establishes a connection to the MCP server using Server-Sent Events (SSE). This * method initiates the SSE connection and sets up the message processing pipeline. @@ -216,7 +228,8 @@ else if (MESSAGE_EVENT_TYPE.equals(event.event())) { } } else { - s.error(new McpError("Received unrecognized SSE event type: " + event.event())); + logger.debug("Received unrecognized SSE event type: {}", event); + s.complete(); } }).transform(handler)).subscribe(); @@ -249,6 +262,7 @@ public Mono sendMessage(JSONRPCMessage message) { return webClient.post() .uri(messageEndpointUri) .contentType(MediaType.APPLICATION_JSON) + .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) .bodyValue(jsonText) .retrieve() .toBodilessEntity() @@ -281,6 +295,7 @@ protected Flux> eventStream() {// @formatter:off .get() .uri(this.sseEndpoint) .accept(MediaType.TEXT_EVENT_STREAM) + .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) .retrieve() .bodyToFlux(SSE_TYPE) .retryWhen(Retry.from(retrySignal -> retrySignal.handle(inboundRetryHandler))); diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index fde067f03..aaf7bab46 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -1,6 +1,12 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + package io.modelcontextprotocol.server.transport; import java.io.IOException; +import java.time.Duration; +import java.util.List; import java.util.concurrent.ConcurrentHashMap; import com.fasterxml.jackson.core.type.TypeReference; @@ -10,7 +16,10 @@ import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.Exceptions; @@ -76,6 +85,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + private static final String MCP_PROTOCOL_VERSION = "2025-06-18"; + /** * Default SSE endpoint path as specified by the MCP transport specification. */ @@ -109,6 +120,12 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ private volatile boolean isClosing = false; + /** + * Keep-alive scheduler for managing session pings. Activated if keepAliveInterval is + * set. Disabled by default. + */ + private KeepAliveScheduler keepAliveScheduler; + /** * Constructs a new WebFlux SSE server transport provider instance with the default * SSE endpoint. @@ -118,7 +135,10 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. * @throws IllegalArgumentException if either parameter is null + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ + @Deprecated public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); } @@ -131,7 +151,10 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. * @throws IllegalArgumentException if either parameter is null + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ + @Deprecated public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); } @@ -145,9 +168,32 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa * messages. This endpoint will be communicated to clients during SSE connection * setup. Must not be null. * @throws IllegalArgumentException if either parameter is null + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ + @Deprecated public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Constructs a new WebFlux SSE server transport provider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param baseUrl webflux message base path + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @param sseEndpoint The SSE endpoint path. Must not be null. + * @param keepAliveInterval The interval for sending keep-alive pings to clients. + * @throws IllegalArgumentException if either parameter is null + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. + */ + @Deprecated + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); @@ -161,6 +207,22 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) .build(); + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); } @Override @@ -209,15 +271,6 @@ public Mono notifyClients(String method, Object params) { /** * Initiates a graceful shutdown of all the sessions. This method ensures all active * sessions are properly closed and cleaned up. - * - *

- * The shutdown process: - *

    - *
  • Marks the transport as closing to prevent new connections
  • - *
  • Closes each active session
  • - *
  • Removes closed sessions from the sessions map
  • - *
  • Times out after 5 seconds if shutdown takes too long
  • - *
* @return A Mono that completes when all sessions have been closed */ @Override @@ -225,7 +278,14 @@ public Mono closeGracefully() { return Flux.fromIterable(sessions.values()) .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) .flatMap(McpServerSession::closeGracefully) - .then(); + .then() + .doOnSuccess(v -> { + logger.debug("Graceful shutdown completed"); + sessions.clear(); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); } /** @@ -396,6 +456,8 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private Duration keepAliveInterval; + /** * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP * messages. @@ -446,6 +508,17 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + /** + * Sets the interval for sending keep-alive pings to clients. + * @param keepAliveInterval The keep-alive interval duration. If null, keep-alive + * is disabled. + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + /** * Builds a new instance of {@link WebFluxSseServerTransportProvider} with the * configured settings. @@ -456,7 +529,8 @@ public WebFluxSseServerTransportProvider build() { Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(messageEndpoint, "Message endpoint must be set"); - return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); + return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, + keepAliveInterval); } } diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java new file mode 100644 index 000000000..23fff25b3 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStatelessServerTransport.java @@ -0,0 +1,216 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.McpStatelessServerHandler; +import io.modelcontextprotocol.server.DefaultMcpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.util.List; + +/** + * Implementation of a WebFlux based {@link McpStatelessServerTransport}. + * + * @author Dariusz Jędrzejczyk + */ +public class WebFluxStatelessServerTransport implements McpStatelessServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(WebFluxStatelessServerTransport.class); + + private final ObjectMapper objectMapper; + + private final String mcpEndpoint; + + private final RouterFunction routerFunction; + + private McpStatelessServerHandler mcpHandler; + + private McpTransportContextExtractor contextExtractor; + + private volatile boolean isClosing = false; + + private WebFluxStatelessServerTransport(ObjectMapper objectMapper, String mcpEndpoint, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(objectMapper, "objectMapper must not be null"); + Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + + this.objectMapper = objectMapper; + this.mcpEndpoint = mcpEndpoint; + this.contextExtractor = contextExtractor; + this.routerFunction = RouterFunctions.route() + .GET(this.mcpEndpoint, this::handleGet) + .POST(this.mcpEndpoint, this::handlePost) + .build(); + } + + @Override + public void setMcpHandler(McpStatelessServerHandler mcpHandler) { + this.mcpHandler = mcpHandler; + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> this.isClosing = true); + } + + /** + * Returns the WebFlux router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

+ * The router function defines one endpoint handling two HTTP methods: + *

    + *
  • GET {messageEndpoint} - Unsupported, returns 405 METHOD NOT ALLOWED
  • + *
  • POST {messageEndpoint} - For handling client requests and notifications
  • + *
+ * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + private Mono handleGet(ServerRequest request) { + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + } + + private Mono handlePost(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) + && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) { + return ServerResponse.badRequest().build(); + } + + return request.bodyToMono(String.class).flatMap(body -> { + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + return this.mcpHandler.handleRequest(transportContext, jsonrpcRequest) + .flatMap(jsonrpcResponse -> ServerResponse.ok() + .contentType(MediaType.APPLICATION_JSON) + .bodyValue(jsonrpcResponse)); + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + return this.mcpHandler.handleNotification(transportContext, jsonrpcNotification) + .then(ServerResponse.accepted().build()); + } + else { + return ServerResponse.badRequest() + .bodyValue(new McpError("The server accepts either requests or notifications")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); + } + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + /** + * Create a builder for the server. + * @return a fresh {@link Builder} instance. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebFluxStatelessServerTransport}. + *

+ * This builder provides a fluent API for configuring and creating instances of + * WebFluxSseServerTransportProvider with custom settings. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private String mcpEndpoint = "/mcp"; + + private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + + private Builder() { + // used by a static method + } + + /** + * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param objectMapper The ObjectMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if objectMapper is null + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.mcpEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Builds a new instance of {@link WebFluxStatelessServerTransport} with the + * configured settings. + * @return A new WebFluxSseServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public WebFluxStatelessServerTransport build() { + Assert.notNull(objectMapper, "ObjectMapper must be set"); + Assert.notNull(mcpEndpoint, "Message endpoint must be set"); + + return new WebFluxStatelessServerTransport(objectMapper, mcpEndpoint, contextExtractor); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java new file mode 100644 index 000000000..f3f6c2c33 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -0,0 +1,487 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.DefaultMcpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.McpStreamableServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; +import reactor.core.Disposable; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Implementation of a WebFlux based {@link McpStreamableServerTransportProvider}. + * + * @author Dariusz Jędrzejczyk + */ +public class WebFluxStreamableServerTransportProvider implements McpStreamableServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebFluxStreamableServerTransportProvider.class); + + public static final String MESSAGE_EVENT_TYPE = "message"; + + private final ObjectMapper objectMapper; + + private final String mcpEndpoint; + + private final boolean disallowDelete; + + private final RouterFunction routerFunction; + + private McpStreamableServerSession.Factory sessionFactory; + + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + private McpTransportContextExtractor contextExtractor; + + private volatile boolean isClosing = false; + + private KeepAliveScheduler keepAliveScheduler; + + private WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, + McpTransportContextExtractor contextExtractor, boolean disallowDelete, + Duration keepAliveInterval) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(mcpEndpoint, "Message endpoint must not be null"); + Assert.notNull(contextExtractor, "Context extractor must not be null"); + + this.objectMapper = objectMapper; + this.mcpEndpoint = mcpEndpoint; + this.contextExtractor = contextExtractor; + this.disallowDelete = disallowDelete; + this.routerFunction = RouterFunctions.route() + .GET(this.mcpEndpoint, this::handleGet) + .POST(this.mcpEndpoint, this::handlePost) + .DELETE(this.mcpEndpoint, this::handleDelete) + .build(); + + if (keepAliveInterval != null) { + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(this.sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26); + } + + @Override + public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + @Override + public Mono notifyClients(String method, Object params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + this.isClosing = true; + return Flux.fromIterable(sessions.values()) + .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) + .flatMap(McpStreamableServerSession::closeGracefully) + .then(); + }).then().doOnSuccess(v -> { + sessions.clear(); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); + } + + /** + * Returns the WebFlux router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

+ * The router function defines one endpoint with three methods: + *

    + *
  • GET {messageEndpoint} - For the client listening SSE stream
  • + *
  • POST {messageEndpoint} - For receiving client messages
  • + *
  • DELETE {messageEndpoint} - For removing sessions
  • + *
+ * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Opens the listening SSE streams for clients. + * @param request The incoming server request + * @return A Mono which emits a response with the SSE event stream + */ + private Mono handleGet(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + return Mono.defer(() -> { + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) { + return ServerResponse.badRequest().build(); + } + + if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { + return ServerResponse.badRequest().build(); // TODO: say we need a session + // id + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) { + String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(session.replay(lastId), ServerSentEvent.class); + } + + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport( + sink); + McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session + .listeningStream(sessionTransport); + sink.onDispose(listeningStream::close); + }), ServerSentEvent.class); + + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + /** + * Handles incoming JSON-RPC messages from clients. + * @param request The incoming server request containing the JSON-RPC message + * @return A Mono with the response appropriate to a particular Streamable HTTP flow. + */ + private Mono handlePost(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) + && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) { + return ServerResponse.badRequest().build(); + } + + return request.bodyToMono(String.class).flatMap(body -> { + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest + && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { + McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(), + new TypeReference() { + }); + McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory + .startSession(initializeRequest); + sessions.put(init.session().getId(), init.session()); + return init.initResult().map(initializeResult -> { + McpSchema.JSONRPCResponse jsonrpcResponse = new McpSchema.JSONRPCResponse( + McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initializeResult, null); + try { + return this.objectMapper.writeValueAsString(jsonrpcResponse); + } + catch (IOException e) { + logger.warn("Failed to serialize initResponse", e); + throw Exceptions.propagate(e); + } + }) + .flatMap(initResult -> ServerResponse.ok() + .contentType(MediaType.APPLICATION_JSON) + .header(HttpHeaders.MCP_SESSION_ID, init.session().getId()) + .bodyValue(initResult)); + } + + if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { + return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing")); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = sessions.get(sessionId); + + if (session == null) { + return ServerResponse.status(HttpStatus.NOT_FOUND) + .bodyValue(new McpError("Session not found: " + sessionId)); + } + + if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) { + return session.accept(jsonrpcResponse).then(ServerResponse.accepted().build()); + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + return session.accept(jsonrpcNotification).then(ServerResponse.accepted().build()); + } + else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + WebFluxStreamableMcpSessionTransport st = new WebFluxStreamableMcpSessionTransport(sink); + Mono stream = session.responseStream(jsonrpcRequest, st); + Disposable streamSubscription = stream.onErrorComplete(err -> { + sink.error(err); + return true; + }).contextWrite(sink.contextView()).subscribe(); + sink.onCancel(streamSubscription); + }), ServerSentEvent.class); + } + else { + return ServerResponse.badRequest().bodyValue(new McpError("Unknown message type")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); + } + }) + .switchIfEmpty(ServerResponse.badRequest().build()) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + private Mono handleDelete(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + return Mono.defer(() -> { + if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { + return ServerResponse.badRequest().build(); // TODO: say we need a session + // id + } + + if (this.disallowDelete) { + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + return session.delete().then(ServerResponse.ok().build()); + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + private class WebFluxStreamableMcpSessionTransport implements McpStreamableServerTransport { + + private final FluxSink> sink; + + public WebFluxStreamableMcpSessionTransport(FluxSink> sink) { + this.sink = sink; + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return this.sendMessage(message, null); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) { + return Mono.fromSupplier(() -> { + try { + return objectMapper.writeValueAsString(message); + } + catch (IOException e) { + throw Exceptions.propagate(e); + } + }).doOnNext(jsonText -> { + ServerSentEvent event = ServerSentEvent.builder() + .id(messageId) + .event(MESSAGE_EVENT_TYPE) + .data(jsonText) + .build(); + sink.next(event); + }).doOnError(e -> { + // TODO log with sessionid + Throwable exception = Exceptions.unwrap(e); + sink.error(exception); + }).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(sink::complete); + } + + @Override + public void close() { + sink.complete(); + } + + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebFluxStreamableServerTransportProvider}. + *

+ * This builder provides a fluent API for configuring and creating instances of + * WebFluxStreamableServerTransportProvider with custom settings. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private String mcpEndpoint = "/mcp"; + + private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + + private boolean disallowDelete; + + private Duration keepAliveInterval; + + private Builder() { + // used by a static method + } + + /** + * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param objectMapper The ObjectMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if objectMapper is null + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.mcpEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Sets whether the session removal capability is disabled. + * @param disallowDelete if {@code true}, the DELETE endpoint will not be + * supported and sessions won't be deleted. + * @return this builder instance + */ + public Builder disallowDelete(boolean disallowDelete) { + this.disallowDelete = disallowDelete; + return this; + } + + /** + * Sets the keep-alive interval for the server transport. + * @param keepAliveInterval The interval for sending keep-alive messages. If null, + * no keep-alive will be scheduled. + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Builds a new instance of {@link WebFluxStreamableServerTransportProvider} with + * the configured settings. + * @return A new WebFluxStreamableServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public WebFluxStreamableServerTransportProvider build() { + Assert.notNull(objectMapper, "ObjectMapper must be set"); + Assert.notNull(mcpEndpoint, "Message endpoint must be set"); + + return new WebFluxStreamableServerTransportProvider(objectMapper, mcpEndpoint, contextExtractor, + disallowDelete, keepAliveInterval); + } + + } + +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index 23ddf6173..a1f1a8947 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -1,35 +1,15 @@ /* * Copyright 2024 - 2024 the original author or authors. */ -package io.modelcontextprotocol; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertWith; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; +package io.modelcontextprotocol; import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.stream.Collectors; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; -import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.server.RouterFunctions; @@ -39,36 +19,14 @@ import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.server.McpServer; -import io.modelcontextprotocol.server.McpServerFeatures; -import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification; import io.modelcontextprotocol.server.TestUtil; import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; -import io.modelcontextprotocol.spec.McpSchema.CompleteResult; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; -import io.modelcontextprotocol.spec.McpSchema.ElicitResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; -import io.modelcontextprotocol.spec.McpSchema.Prompt; -import io.modelcontextprotocol.spec.McpSchema.PromptArgument; -import io.modelcontextprotocol.spec.McpSchema.PromptReference; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import net.javacrumbs.jsonunit.core.Option; -import reactor.core.publisher.Mono; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; -import reactor.test.StepVerifier; -class WebFluxSseIntegrationTests { +class WebFluxSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); @@ -80,1413 +38,55 @@ class WebFluxSseIntegrationTests { private WebFluxSseServerTransportProvider mcpServerTransportProvider; - ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + @Override + protected void prepareClients(int port, String mcpEndpoint) { - @BeforeEach - public void before() { - - this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider.Builder() - .objectMapper(new ObjectMapper()) - .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) - .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build(); - - HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); - ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); - this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build()).requestTimeout(Duration.ofHours(10))); - clientBuilders.put("httpclient", - McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) - .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build())); clientBuilders.put("webflux", McpClient .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build())); - - } - - @AfterEach - public void after() { - if (httpServer != null) { - httpServer.disposeNow(); - } - } - - // --------------------------------------- - // Sampling Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageWithoutSamplingCapabilities(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> exchange.createMessage(mock(CreateMessageRequest.class)) - .thenReturn(mock(CallToolResult.class))) - .build(); - - var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - - try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) - .build();) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - } - } - server.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - AtomicReference samplingResult = new AtomicReference<>(); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var createMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - return exchange.createMessage(createMessageRequest) - .doOnNext(samplingResult::set) - .thenReturn(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - assertWith(samplingResult.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }); - } - mcpServer.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { - - // Client - var clientBuilder = clientBuilders.get(clientType); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - AtomicReference samplingResult = new AtomicReference<>(); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) .build()) - .build(); - - return exchange.createMessage(craeteMessageRequest) - .doOnNext(samplingResult::set) - .thenReturn(callResponse); - }) - .build(); + .requestTimeout(Duration.ofHours(10))); - var mcpServer = McpServer.async(mcpServerTransportProvider) - .requestTimeout(Duration.ofSeconds(4)) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - assertWith(samplingResult.get(), result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }); - } - - mcpServer.closeGracefully().block(); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { - - // Client - var clientBuilder = clientBuilders.get(clientType); - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .build(); - - return exchange.createMessage(craeteMessageRequest).thenReturn(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .requestTimeout(Duration.ofSeconds(1)) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("within 1000ms"); - - } - - mcpServer.closeGracefully().block(); + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(mcpServerTransportProvider); } - // --------------------------------------- - // Elicitation Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateElicitationWithoutElicitationCapabilities(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - exchange.createElicitation(mock(ElicitRequest.class)).block(); - - return Mono.just(mock(CallToolResult.class)); - }) - .build(); - - var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - - try ( - // Create client without elicitation capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with elicitation capabilities"); - } - } - server.closeGracefully().block(); + @Override + protected SingleSessionSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(mcpServerTransportProvider); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateElicitationSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - mcpServer.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { - - // Client - var clientBuilder = clientBuilders.get(clientType); - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(3)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCreateElicitationWithRequestTimeoutFail(String clientType) { - - var latch = new CountDownLatch(1); - // Client - var clientBuilder = clientBuilders.get(clientType); - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - - try { - if (!latch.await(2, TimeUnit.SECONDS)) { - throw new RuntimeException("Timeout waiting for elicitation processing"); - } - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) // 1 second. - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("within 1000ms"); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsSuccess(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); - - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithoutCapability(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - exchange.listRoots(); // try to list roots - - return mock(CallToolResult.class); - }) - .build(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { - }).tools(tool).build(); - - // Create client without roots capability - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); - } - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsNotificationWithEmptyRootsList(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(List.of()) // Empty roots list - .build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsWithMultipleHandlers(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef1 = new AtomicReference<>(); - AtomicReference> rootsRef2 = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testRootsServerCloseWithActiveSubscription(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testToolCallSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }) - .build(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testToolListChangeHandlingSuccess(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }) - .build(); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - rootsRef.set(toolsUpdate); - }).build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); - - // Remove a tool - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema)) - .callHandler((exchange, request) -> callResponse) - .build(); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testInitialize(String clientType) { - - var clientBuilder = clientBuilders.get(clientType); - - var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testLoggingNotification(String clientType) throws InterruptedException { - int expectedNotificationsCount = 3; - CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); - // Create a list to store received logging notifications - List receivedNotifications = new CopyOnWriteArrayList<>(); - - var clientBuilder = clientBuilders.get(clientType); - - // Create server with a tool that sends logging notifications - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("logging-test", "Test logging notifications", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - // Create and send notifications with different levels - - //@formatter:off - return exchange // This should be filtered out (DEBUG < NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.DEBUG) - .logger("test-logger") - .data("Debug message") - .build()) - .then(exchange // This should be sent (NOTICE >= NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.NOTICE) - .logger("test-logger") - .data("Notice message") - .build())) - .then(exchange // This should be sent (ERROR > NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Error message") - .build())) - .then(exchange // This should be filtered out (INFO < NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Another info message") - .build())) - .then(exchange // This should be sent (ERROR >= NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Another error message") - .build())) - .thenReturn(new CallToolResult("Logging test completed", false)); - //@formatter:on - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().tools(true).build()) - .tools(tool) - .build(); - - try ( - // Create client with logging notification handler - var mcpClient = clientBuilder.loggingConsumer(notification -> { - receivedNotifications.add(notification); - latch.countDown(); - }).build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Set minimum logging level to NOTICE - mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); - - // Call the tool that sends logging notifications - CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); - - assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); - - // Should have received 3 notifications (1 NOTICE and 2 ERROR) - assertThat(receivedNotifications).hasSize(expectedNotificationsCount); - - Map notificationMap = receivedNotifications.stream() - .collect(Collectors.toMap(n -> n.data(), n -> n)); - - // First notification should be NOTICE level - assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); - assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); - - // Second notification should be ERROR level - assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); - - // Third notification should be ERROR level - assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); - } - mcpServer.close(); - } - - // --------------------------------------- - // Progress Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testProgressNotification(String clientType) throws InterruptedException { - int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress - // token - CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); - // Create a list to store received logging notifications - List receivedNotifications = new CopyOnWriteArrayList<>(); - - var clientBuilder = clientBuilders.get(clientType); - - // Create server with a tool that sends logging notifications - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(McpSchema.Tool.builder() - .name("progress-test") - .description("Test progress notifications") - .inputSchema(emptyJsonSchema) - .build()) - .callHandler((exchange, request) -> { - - // Create and send notifications - var progressToken = (String) request.meta().get("progressToken"); - - return exchange - .progressNotification( - new McpSchema.ProgressNotification(progressToken, 0.0, 1.0, "Processing started")) - .then(exchange.progressNotification( - new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Processing data"))) - .then(// Send a progress notification with another progress value - // should - exchange.progressNotification(new McpSchema.ProgressNotification("another-progress-token", - 0.0, 1.0, "Another processing started"))) - .then(exchange.progressNotification( - new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) - .thenReturn(new CallToolResult(("Progress test completed"), false)); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try ( - // Create client with progress notification handler - var mcpClient = clientBuilder.progressConsumer(notification -> { - receivedNotifications.add(notification); - latch.countDown(); - }).build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call the tool that sends progress notifications - McpSchema.CallToolRequest callToolRequest = McpSchema.CallToolRequest.builder() - .name("progress-test") - .meta(Map.of("progressToken", "test-progress-token")) - .build(); - CallToolResult result = mcpClient.callTool(callToolRequest); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); - - assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); - - // Should have received 3 notifications - assertThat(receivedNotifications).hasSize(expectedNotificationsCount); - - Map notificationMap = receivedNotifications.stream() - .collect(Collectors.toMap(n -> n.message(), n -> n)); - - // First notification should be 0.0/1.0 progress - assertThat(notificationMap.get("Processing started").progressToken()).isEqualTo("test-progress-token"); - assertThat(notificationMap.get("Processing started").progress()).isEqualTo(0.0); - assertThat(notificationMap.get("Processing started").total()).isEqualTo(1.0); - assertThat(notificationMap.get("Processing started").message()).isEqualTo("Processing started"); - - // Second notification should be 0.5/1.0 progress - assertThat(notificationMap.get("Processing data").progressToken()).isEqualTo("test-progress-token"); - assertThat(notificationMap.get("Processing data").progress()).isEqualTo(0.5); - assertThat(notificationMap.get("Processing data").total()).isEqualTo(1.0); - assertThat(notificationMap.get("Processing data").message()).isEqualTo("Processing data"); - - // Third notification should be another progress token with 0.0/1.0 progress - assertThat(notificationMap.get("Another processing started").progressToken()) - .isEqualTo("another-progress-token"); - assertThat(notificationMap.get("Another processing started").progress()).isEqualTo(0.0); - assertThat(notificationMap.get("Another processing started").total()).isEqualTo(1.0); - assertThat(notificationMap.get("Another processing started").message()) - .isEqualTo("Another processing started"); - - // Fourth notification should be 1.0/1.0 progress - assertThat(notificationMap.get("Processing completed").progressToken()).isEqualTo("test-progress-token"); - assertThat(notificationMap.get("Processing completed").progress()).isEqualTo(1.0); - assertThat(notificationMap.get("Processing completed").total()).isEqualTo(1.0); - assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); - } - finally { - mcpServer.close(); - } - } - - // --------------------------------------- - // Completion Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : Completion call") - @ValueSource(strings = { "httpclient", "webflux" }) - void testCompletionShouldReturnExpectedSuggestions(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - var expectedValues = List.of("python", "pytorch", "pyside"); - var completionResponse = new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total - true // hasMore - )); - - AtomicReference samplingRequest = new AtomicReference<>(); - BiFunction completionHandler = (mcpSyncServerExchange, - request) -> { - samplingRequest.set(request); - return completionResponse; - }; - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().completions().build()) - .prompts(new McpServerFeatures.SyncPromptSpecification( - new Prompt("code_review", "Code review", "this is code review prompt", - List.of(new PromptArgument("language", "Language", "string", false))), - (mcpSyncServerExchange, getPromptRequest) -> null)) - .completions(new McpServerFeatures.SyncCompletionSpecification( - new McpSchema.PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CompleteRequest request = new CompleteRequest( - new PromptReference("ref/prompt", "code_review", "Code review"), - new CompleteRequest.CompleteArgument("language", "py")); - - CompleteResult result = mcpClient.completeCompletion(request); - - assertThat(result).isNotNull(); - - assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); - assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); - assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Ping Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testPingSuccess(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Create server with a tool that uses ping functionality - AtomicReference executionOrder = new AtomicReference<>(""); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("ping-async-test", "Test ping async behavior", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - executionOrder.set(executionOrder.get() + "1"); - - // Test async ping behavior - return exchange.ping().doOnNext(result -> { - - assertThat(result).isNotNull(); - // Ping should return an empty object or map - assertThat(result).isInstanceOf(Map.class); - - executionOrder.set(executionOrder.get() + "2"); - assertThat(result).isNotNull(); - }).then(Mono.fromCallable(() -> { - executionOrder.set(executionOrder.get() + "3"); - return new CallToolResult("Async ping test completed", false); - })); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call the tool that tests ping async behavior - CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); - - // Verify execution order - assertThat(executionOrder.get()).isEqualTo("123"); - } - - mcpServer.closeGracefully().block(); - } - - // --------------------------------------- - // Tool Structured Output Schema Tests - // --------------------------------------- - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testStructuredOutputValidationSuccess(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Create a tool with output schema - Map outputSchema = Map.of( - "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", - Map.of("type", "string"), "timestamp", Map.of("type", "string")), - "required", List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - String expression = (String) request.getOrDefault("expression", "2 + 3"); - double result = evaluateExpression(expression); - return CallToolResult.builder() - .structuredContent( - Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) - .build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Verify tool is listed with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call tool with valid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - assertThatJson(((McpSchema.TextContent) response.content().get(0)).text()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); - - assertThat(response.structuredContent()).isNotNull(); - assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testStructuredOutputValidationFailure(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", - List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - // Return invalid structured output. Result should be number, missing - // operation - return CallToolResult.builder() - .addTextContent("Invalid calculation") - .structuredContent(Map.of("result", "not-a-number", "extra", "field")) - .build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool with invalid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).contains("Validation failed"); - } - - mcpServer.close(); - } - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testStructuredOutputMissingStructuredContent(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number")), "required", List.of("result")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - // Return result without structured content but tool has output schema - return CallToolResult.builder().addTextContent("Calculation completed").build(); - }); + @BeforeEach + public void before() { - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) + this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider.Builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) .build(); - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool that should return structured content but doesn't - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).isEqualTo( - "Response missing structured content which is expected when calling tool with non-empty outputSchema"); - } + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - mcpServer.close(); + prepareClients(PORT, null); } - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "httpclient", "webflux" }) - void testStructuredOutputRuntimeToolAddition(String clientType) { - var clientBuilder = clientBuilders.get(clientType); - - // Start server without tools - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Initially no tools - assertThat(mcpClient.listTools().tools()).isEmpty(); - - // Add tool with output schema at runtime - Map outputSchema = Map.of("type", "object", "properties", - Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", - List.of("message", "count")); - - Tool dynamicTool = Tool.builder() - .name("dynamic-tool") - .description("Dynamically added tool") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification toolSpec = new McpServerFeatures.SyncToolSpecification(dynamicTool, - (exchange, request) -> { - int count = (Integer) request.getOrDefault("count", 1); - return CallToolResult.builder() - .addTextContent("Dynamic tool executed " + count + " times") - .structuredContent(Map.of("message", "Dynamic execution", "count", count)) - .build(); - }); - - // Add tool to server - mcpServer.addTool(toolSpec); - - // Wait for tool list change notification - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(mcpClient.listTools().tools()).hasSize(1); - }); - - // Verify tool was added with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call dynamically added tool - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) response.content().get(0)).text()) - .isEqualTo("Dynamic tool executed 3 times"); - - assertThat(response.structuredContent()).isNotNull(); - assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"count":3,"message":"Dynamic execution"}""")); + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); } - - mcpServer.close(); - } - - private double evaluateExpression(String expression) { - // Simple expression evaluator for testing - return switch (expression) { - case "2 + 3" -> 5.0; - case "10 * 2" -> 20.0; - case "7 + 8" -> 15.0; - case "5 + 3" -> 8.0; - default -> 0.0; - }; } } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java new file mode 100644 index 000000000..302c58c5f --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStatelessIntegrationTests.java @@ -0,0 +1,86 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.time.Duration; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; +import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +class WebFluxStatelessIntegrationTests extends AbstractStatelessIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxStatelessServerTransport mcpStreamableServerTransport; + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); + clientBuilders + .put("webflux", McpClient + .sync(WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()) + .initializationTimeout(Duration.ofHours(10)) + .requestTimeout(Duration.ofHours(10))); + } + + @Override + protected StatelessAsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpStreamableServerTransport); + } + + @Override + protected StatelessSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpStreamableServerTransport); + } + + @BeforeEach + public void before() { + this.mcpStreamableServerTransport = WebFluxStatelessServerTransport.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .build(); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpStreamableServerTransport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + prepareClients(PORT, null); + } + + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java new file mode 100644 index 000000000..616c6dcf8 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java @@ -0,0 +1,89 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol; + +import java.time.Duration; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +class WebFluxStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxStreamableServerTransportProvider mcpStreamableServerTransportProvider; + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()).requestTimeout(Duration.ofHours(10))); + clientBuilders.put("webflux", + McpClient + .sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()) + .requestTimeout(Duration.ofHours(10))); + } + + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(mcpStreamableServerTransportProvider); + } + + @Override + protected SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(mcpStreamableServerTransportProvider); + } + + @BeforeEach + public void before() { + + this.mcpStreamableServerTransportProvider = WebFluxStreamableServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .build(); + + HttpHandler httpHandler = RouterFunctions + .toHttpHandler(mcpStreamableServerTransportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + prepareClients(PORT, null); + } + + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java index 7c4d35db8..191f10376 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientResiliencyTests.java @@ -1,3 +1,7 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + package io.modelcontextprotocol.client; import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java index 5ff707b3c..f8a16c153 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpAsyncClientTests.java @@ -1,3 +1,7 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + package io.modelcontextprotocol.client; import org.junit.jupiter.api.Timeout; diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java index 70260c8bf..5e9960d0e 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebClientStreamableHttpSyncClientTests.java @@ -1,3 +1,7 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + package io.modelcontextprotocol.client; import org.junit.jupiter.api.Timeout; diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java index 42b91d14e..1cf5dffe2 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java @@ -6,6 +6,7 @@ import java.time.Duration; import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; @@ -77,6 +78,11 @@ public int getInboundMessageCount() { return inboundMessageCount.get(); } + public void simulateSseComment(String comment) { + events.tryEmitNext(ServerSentEvent.builder().comment(comment).build()); + inboundMessageCount.incrementAndGet(); + } + public void simulateEndpointEvent(String jsonMessage) { events.tryEmitNext(ServerSentEvent.builder().event("endpoint").data(jsonMessage).build()); inboundMessageCount.incrementAndGet(); @@ -158,6 +164,27 @@ void testBuilderPattern() { assertThatCode(() -> transport4.closeGracefully().block()).doesNotThrowAnyException(); } + @Test + void testCommentSseMessage() { + // If the line starts with a character (:) are comment lins and should be ingored + // https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation + + CopyOnWriteArrayList droppedErrors = new CopyOnWriteArrayList<>(); + reactor.core.publisher.Hooks.onErrorDropped(droppedErrors::add); + + try { + // Simulate receiving the SSE comment line + transport.simulateSseComment("sse comment"); + + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + + assertThat(droppedErrors).hasSize(0); + } + finally { + reactor.core.publisher.Hooks.resetOnErrorDropped(); + } + } + @Test void testMessageProcessing() { // Create a test message diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java index cc33e7b94..a3bdf10b0 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java @@ -29,8 +29,7 @@ class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { private DisposableServer httpServer; - @Override - protected McpServerTransportProvider createMcpTransportProvider() { + private McpServerTransportProvider createMcpTransportProvider() { var transportProvider = new WebFluxSseServerTransportProvider.Builder().objectMapper(new ObjectMapper()) .messageEndpoint(MESSAGE_ENDPOINT) .build(); @@ -41,6 +40,11 @@ protected McpServerTransportProvider createMcpTransportProvider() { return transportProvider; } + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + @Override protected void onStart() { } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java index 2fc104538..3e28e96b8 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java @@ -32,7 +32,11 @@ class WebFluxSseMcpSyncServerTests extends AbstractMcpSyncServerTests { private WebFluxSseServerTransportProvider transportProvider; @Override - protected McpServerTransportProvider createMcpTransportProvider() { + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + + private McpServerTransportProvider createMcpTransportProvider() { transportProvider = new WebFluxSseServerTransportProvider.Builder().objectMapper(new ObjectMapper()) .messageEndpoint(MESSAGE_ENDPOINT) .build(); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java new file mode 100644 index 000000000..959f2f472 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import org.junit.jupiter.api.Timeout; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +/** + * Tests for {@link McpAsyncServer} using + * {@link WebFluxStreamableServerTransportProvider}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxStreamableMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + var transportProvider = WebFluxStreamableServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .build(); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + return transportProvider; + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java new file mode 100644 index 000000000..3396d489c --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import org.junit.jupiter.api.Timeout; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +/** + * Tests for {@link McpAsyncServer} using + * {@link WebFluxStreamableServerTransportProvider}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxStreamableMcpSyncServerTests extends AbstractMcpSyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + var transportProvider = WebFluxStreamableServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .build(); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + return transportProvider; + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java index 0ab72a99f..dfb004e9b 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/transport/BlockingInputStream.java @@ -1,6 +1,7 @@ /* * Copyright 2024 - 2024 the original author or authors. */ + package io.modelcontextprotocol.server.transport; import java.io.IOException; diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index 4c6d37bf9..ea262d3a1 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -6,13 +6,13 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.12.0-SNAPSHOT ../../pom.xml mcp-spring-webmvc jar - Spring Web MVC implementation of the Java MCP SSE transport - + Spring Web MVC transports + Web MVC implementation for the SSE and Streamable Http Server transports https://github.com/modelcontextprotocol/java-sdk @@ -25,20 +25,27 @@ io.modelcontextprotocol.sdk mcp - 0.11.0-SNAPSHOT + 0.12.0-SNAPSHOT + + + + org.springframework + spring-webmvc + ${springframework.version} io.modelcontextprotocol.sdk mcp-test - 0.11.0-SNAPSHOT + 0.12.0-SNAPSHOT test - - org.springframework - spring-webmvc - ${springframework.version} + + io.modelcontextprotocol.sdk + mcp-spring-webflux + 0.12.0-SNAPSHOT + test diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index 114eff607..ff452ca74 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -6,8 +6,10 @@ import java.io.IOException; import java.time.Duration; +import java.util.List; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; @@ -15,8 +17,11 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -106,6 +111,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi */ private volatile boolean isClosing = false; + private KeepAliveScheduler keepAliveScheduler; + /** * Constructs a new WebMvcSseServerTransportProvider instance with the default SSE * endpoint. @@ -115,7 +122,10 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi * messages via HTTP POST. This endpoint will be communicated to clients through the * SSE connection's initial endpoint event. * @throws IllegalArgumentException if either objectMapper or messageEndpoint is null + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ + @Deprecated public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) { this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT); } @@ -129,7 +139,10 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * SSE connection's initial endpoint event. * @param sseEndpoint The endpoint URI where clients establish their SSE connections. * @throws IllegalArgumentException if any parameter is null + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ + @Deprecated public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { this(objectMapper, "", messageEndpoint, sseEndpoint); } @@ -145,9 +158,33 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * SSE connection's initial endpoint event. * @param sseEndpoint The endpoint URI where clients establish their SSE connections. * @throws IllegalArgumentException if any parameter is null + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ + @Deprecated public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Constructs a new WebMvcSseServerTransportProvider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param baseUrl The base URL for the message endpoint, used to construct the full + * endpoint URL for clients. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * * @param keepAliveInterval The interval for sending keep-alive messages to + * @throws IllegalArgumentException if any parameter is null + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. + */ + @Deprecated + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); @@ -161,6 +198,22 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) .build(); + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); } @Override @@ -208,10 +261,13 @@ public Mono closeGracefully() { return Flux.fromIterable(sessions.values()).doFirst(() -> { this.isClosing = true; logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); - }) - .flatMap(McpServerSession::closeGracefully) - .then() - .doOnSuccess(v -> logger.debug("Graceful shutdown completed")); + }).flatMap(McpServerSession::closeGracefully).then().doOnSuccess(v -> { + logger.debug("Graceful shutdown completed"); + sessions.clear(); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); } /** @@ -339,6 +395,12 @@ private class WebMvcMcpSessionTransport implements McpServerTransport { private final SseBuilder sseBuilder; + /** + * Lock to ensure thread-safe access to the SSE builder when sending messages. + * This prevents concurrent modifications that could lead to corrupted SSE events. + */ + private final ReentrantLock sseBuilderLock = new ReentrantLock(); + /** * Creates a new session transport with the specified ID and SSE builder. * @param sessionId The unique identifier for this session @@ -358,6 +420,7 @@ private class WebMvcMcpSessionTransport implements McpServerTransport { @Override public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.fromRunnable(() -> { + sseBuilderLock.lock(); try { String jsonText = objectMapper.writeValueAsString(message); sseBuilder.id(sessionId).event(MESSAGE_EVENT_TYPE).data(jsonText); @@ -367,6 +430,9 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage()); sseBuilder.error(e); } + finally { + sseBuilderLock.unlock(); + } }); } @@ -390,6 +456,7 @@ public T unmarshalFrom(Object data, TypeReference typeRef) { public Mono closeGracefully() { return Mono.fromRunnable(() -> { logger.debug("Closing session transport: {}", sessionId); + sseBuilderLock.lock(); try { sseBuilder.complete(); logger.debug("Successfully completed SSE builder for session {}", sessionId); @@ -397,6 +464,9 @@ public Mono closeGracefully() { catch (Exception e) { logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); } + finally { + sseBuilderLock.unlock(); + } }); } @@ -405,6 +475,7 @@ public Mono closeGracefully() { */ @Override public void close() { + sseBuilderLock.lock(); try { sseBuilder.complete(); logger.debug("Successfully completed SSE builder for session {}", sessionId); @@ -412,6 +483,111 @@ public void close() { catch (Exception e) { logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); } + finally { + sseBuilderLock.unlock(); + } + } + + } + + /** + * Creates a new Builder instance for configuring and creating instances of + * WebMvcSseServerTransportProvider. + * @return A new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of WebMvcSseServerTransportProvider. + *

+ * This builder provides a fluent API for configuring and creating instances of + * WebMvcSseServerTransportProvider with custom settings. + */ + public static class Builder { + + private ObjectMapper objectMapper = new ObjectMapper(); + + private String baseUrl = ""; + + private String messageEndpoint; + + private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + + private Duration keepAliveInterval; + + /** + * Sets the JSON object mapper to use for message serialization/deserialization. + * @param objectMapper The object mapper to use + * @return This builder instance for method chaining + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the base URL for the server transport. + * @param baseUrl The base URL to use + * @return This builder instance for method chaining + */ + public Builder baseUrl(String baseUrl) { + Assert.notNull(baseUrl, "Base URL must not be null"); + this.baseUrl = baseUrl; + return this; + } + + /** + * Sets the endpoint path where clients will send their messages. + * @param messageEndpoint The message endpoint path + * @return This builder instance for method chaining + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.hasText(messageEndpoint, "Message endpoint must not be empty"); + this.messageEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the endpoint path where clients will establish SSE connections. + *

+ * If not specified, the default value of {@link #DEFAULT_SSE_ENDPOINT} will be + * used. + * @param sseEndpoint The SSE endpoint path + * @return This builder instance for method chaining + */ + public Builder sseEndpoint(String sseEndpoint) { + Assert.hasText(sseEndpoint, "SSE endpoint must not be empty"); + this.sseEndpoint = sseEndpoint; + return this; + } + + /** + * Sets the interval for keep-alive pings. + *

+ * If not specified, keep-alive pings will be disabled. + * @param keepAliveInterval The interval duration for keep-alive pings + * @return This builder instance for method chaining + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Builds a new instance of WebMvcSseServerTransportProvider with the configured + * settings. + * @return A new WebMvcSseServerTransportProvider instance + * @throws IllegalStateException if objectMapper or messageEndpoint is not set + */ + public WebMvcSseServerTransportProvider build() { + if (messageEndpoint == null) { + throw new IllegalStateException("MessageEndpoint must be set"); + } + return new WebMvcSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, + keepAliveInterval); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java new file mode 100644 index 000000000..fef1920fc --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java @@ -0,0 +1,241 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.McpStatelessServerHandler; +import io.modelcontextprotocol.server.DefaultMcpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.RouterFunctions; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.util.List; + +/** + * Implementation of a WebMVC based {@link McpStatelessServerTransport}. + * + *

+ * This is the non-reactive version of + * {@link io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport} + * + * @author Christian Tzolov + */ +public class WebMvcStatelessServerTransport implements McpStatelessServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(WebMvcStatelessServerTransport.class); + + private final ObjectMapper objectMapper; + + private final String mcpEndpoint; + + private final RouterFunction routerFunction; + + private McpStatelessServerHandler mcpHandler; + + private McpTransportContextExtractor contextExtractor; + + private volatile boolean isClosing = false; + + private WebMvcStatelessServerTransport(ObjectMapper objectMapper, String mcpEndpoint, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(objectMapper, "objectMapper must not be null"); + Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + + this.objectMapper = objectMapper; + this.mcpEndpoint = mcpEndpoint; + this.contextExtractor = contextExtractor; + this.routerFunction = RouterFunctions.route() + .GET(this.mcpEndpoint, this::handleGet) + .POST(this.mcpEndpoint, this::handlePost) + .build(); + } + + @Override + public void setMcpHandler(McpStatelessServerHandler mcpHandler) { + this.mcpHandler = mcpHandler; + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> this.isClosing = true); + } + + /** + * Returns the WebMVC router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

+ * The router function defines one endpoint handling two HTTP methods: + *

    + *
  • GET {messageEndpoint} - Unsupported, returns 405 METHOD NOT ALLOWED
  • + *
  • POST {messageEndpoint} - For handling client requests and notifications
  • + *
+ * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + private ServerResponse handleGet(ServerRequest request) { + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + } + + private ServerResponse handlePost(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) + && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) { + return ServerResponse.badRequest().build(); + } + + try { + String body = request.body(String.class); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + try { + McpSchema.JSONRPCResponse jsonrpcResponse = this.mcpHandler + .handleRequest(transportContext, jsonrpcRequest) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + return ServerResponse.ok().contentType(MediaType.APPLICATION_JSON).body(jsonrpcResponse); + } + catch (Exception e) { + logger.error("Failed to handle request: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(new McpError("Failed to handle request: " + e.getMessage())); + } + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + try { + this.mcpHandler.handleNotification(transportContext, jsonrpcNotification) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + return ServerResponse.accepted().build(); + } + catch (Exception e) { + logger.error("Failed to handle notification: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(new McpError("Failed to handle notification: " + e.getMessage())); + } + } + else { + return ServerResponse.badRequest() + .body(new McpError("The server accepts either requests or notifications")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().body(new McpError("Invalid message format")); + } + catch (Exception e) { + logger.error("Unexpected error handling message: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(new McpError("Unexpected error: " + e.getMessage())); + } + } + + /** + * Create a builder for the server. + * @return a fresh {@link Builder} instance. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebMvcStatelessServerTransport}. + *

+ * This builder provides a fluent API for configuring and creating instances of + * WebMvcStatelessServerTransport with custom settings. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private String mcpEndpoint = "/mcp"; + + private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + + private Builder() { + // used by a static method + } + + /** + * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param objectMapper The ObjectMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if objectMapper is null + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.mcpEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Builds a new instance of {@link WebMvcStatelessServerTransport} with the + * configured settings. + * @return A new WebMvcStatelessServerTransport instance + * @throws IllegalStateException if required parameters are not set + */ + public WebMvcStatelessServerTransport build() { + Assert.notNull(objectMapper, "ObjectMapper must be set"); + Assert.notNull(mcpEndpoint, "Message endpoint must be set"); + + return new WebMvcStatelessServerTransport(objectMapper, mcpEndpoint, contextExtractor); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java new file mode 100644 index 000000000..fa51a0130 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java @@ -0,0 +1,690 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.RouterFunctions; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; +import org.springframework.web.servlet.function.ServerResponse.SseBuilder; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.server.DefaultMcpTransportContext; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.McpStreamableServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Server-side implementation of the Model Context Protocol (MCP) streamable transport + * layer using HTTP with Server-Sent Events (SSE) through Spring WebMVC. This + * implementation provides a bridge between synchronous WebMVC operations and reactive + * programming patterns to maintain compatibility with the reactive transport interface. + * + *

+ * This is the non-reactive version of + * {@link io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider} + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @see McpStreamableServerTransportProvider + * @see RouterFunction + */ +public class WebMvcStreamableServerTransportProvider implements McpStreamableServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebMvcStreamableServerTransportProvider.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Default base URL for the message endpoint. + */ + public static final String DEFAULT_BASE_URL = ""; + + /** + * The endpoint URI where clients should send their JSON-RPC messages. Defaults to + * "/mcp". + */ + private final String mcpEndpoint; + + /** + * Flag indicating whether DELETE requests are disallowed on the endpoint. + */ + private final boolean disallowDelete; + + private final ObjectMapper objectMapper; + + private final RouterFunction routerFunction; + + private McpStreamableServerSession.Factory sessionFactory; + + /** + * Map of active client sessions, keyed by mcp-session-id. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + private McpTransportContextExtractor contextExtractor; + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + private KeepAliveScheduler keepAliveScheduler; + + /** + * Constructs a new WebMvcStreamableServerTransportProvider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param baseUrl The base URL for the message endpoint, used to construct the full + * endpoint URL for clients. + * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP. This endpoint will handle GET, POST, and DELETE requests. + * @param disallowDelete Whether to disallow DELETE requests on the endpoint. + * @throws IllegalArgumentException if any parameter is null + */ + private WebMvcStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, + boolean disallowDelete, McpTransportContextExtractor contextExtractor, + Duration keepAliveInterval) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); + Assert.notNull(contextExtractor, "McpTransportContextExtractor must not be null"); + + this.objectMapper = objectMapper; + this.mcpEndpoint = mcpEndpoint; + this.disallowDelete = disallowDelete; + this.contextExtractor = contextExtractor; + this.routerFunction = RouterFunctions.route() + .GET(this.mcpEndpoint, this::handleGet) + .POST(this.mcpEndpoint, this::handlePost) + .DELETE(this.mcpEndpoint, this::handleDelete) + .build(); + + if (keepAliveInterval != null) { + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(this.sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26); + } + + @Override + public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a notification to all connected clients through their SSE connections. + * If any errors occur during sending to a particular client, they are logged but + * don't prevent sending to other clients. + * @param method The method name for the notification + * @param params The parameters for the notification + * @return A Mono that completes when the broadcast attempt is finished + */ + @Override + public Mono notifyClients(String method, Object params) { + if (this.sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size()); + + return Mono.fromRunnable(() -> { + this.sessions.values().parallelStream().forEach(session -> { + try { + session.sendNotification(method, params).block(); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()); + } + }); + }); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when all cleanup operations are finished + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + this.isClosing = true; + logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size()); + + this.sessions.values().parallelStream().forEach(session -> { + try { + session.closeGracefully().block(); + } + catch (Exception e) { + logger.error("Failed to close session {}: {}", session.getId(), e.getMessage()); + } + }); + + this.sessions.clear(); + logger.debug("Graceful shutdown completed"); + }).then().doOnSuccess(v -> { + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); + } + + /** + * Returns the RouterFunction that defines the HTTP endpoints for this transport. The + * router function handles three endpoints: + *

    + *
  • GET [mcpEndpoint] - For establishing SSE connections and message replay
  • + *
  • POST [mcpEndpoint] - For receiving JSON-RPC messages from clients
  • + *
  • DELETE [mcpEndpoint] - For session deletion (if enabled)
  • + *
+ * @return The configured RouterFunction for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Setup the listening SSE connections and message replay. + * @param request The incoming server request + * @return A ServerResponse configured for SSE communication, or an error response + */ + private ServerResponse handleGet(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) { + return ServerResponse.badRequest().body("Invalid Accept header. Expected TEXT_EVENT_STREAM"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { + return ServerResponse.badRequest().body("Session ID required in mcp-session-id header"); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + logger.debug("Handling GET request for session: {}", sessionId); + + try { + return ServerResponse.sse(sseBuilder -> { + sseBuilder.onTimeout(() -> { + logger.debug("SSE connection timed out for session: {}", sessionId); + }); + + WebMvcStreamableMcpSessionTransport sessionTransport = new WebMvcStreamableMcpSessionTransport( + sessionId, sseBuilder); + + // Check if this is a replay request + if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) { + String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); + + try { + session.replay(lastId) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .toIterable() + .forEach(message -> { + try { + sessionTransport.sendMessage(message) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + } + catch (Exception e) { + logger.error("Failed to replay message: {}", e.getMessage()); + sseBuilder.error(e); + } + }); + } + catch (Exception e) { + logger.error("Failed to replay messages: {}", e.getMessage()); + sseBuilder.error(e); + } + } + else { + // Establish new listening stream + McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session + .listeningStream(sessionTransport); + + sseBuilder.onComplete(() -> { + logger.debug("SSE connection completed for session: {}", sessionId); + listeningStream.close(); + }); + } + }, Duration.ZERO); + } + catch (Exception e) { + logger.error("Failed to handle GET request for session {}: {}", sessionId, e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Handles POST requests for incoming JSON-RPC messages from clients. + * @param request The incoming server request containing the JSON-RPC message + * @return A ServerResponse indicating success or appropriate error status + */ + private ServerResponse handlePost(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM) + || !acceptHeaders.contains(MediaType.APPLICATION_JSON)) { + return ServerResponse.badRequest() + .body(new McpError("Invalid Accept headers. Expected TEXT_EVENT_STREAM and APPLICATION_JSON")); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + try { + String body = request.body(String.class); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + + // Handle initialization request + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest + && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { + McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(), + new TypeReference() { + }); + McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory + .startSession(initializeRequest); + this.sessions.put(init.session().getId(), init.session()); + + try { + McpSchema.InitializeResult initResult = init.initResult().block(); + + return ServerResponse.ok() + .contentType(MediaType.APPLICATION_JSON) + .header(HttpHeaders.MCP_SESSION_ID, init.session().getId()) + .body(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initResult, + null)); + } + catch (Exception e) { + logger.error("Failed to initialize session: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); + } + } + + // Handle other messages that require a session + if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { + return ServerResponse.badRequest().body(new McpError("Session ID missing")); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.status(HttpStatus.NOT_FOUND) + .body(new McpError("Session not found: " + sessionId)); + } + + if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) { + session.accept(jsonrpcResponse) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + return ServerResponse.accepted().build(); + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + session.accept(jsonrpcNotification) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + return ServerResponse.accepted().build(); + } + else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + // For streaming responses, we need to return SSE + return ServerResponse.sse(sseBuilder -> { + sseBuilder.onComplete(() -> { + logger.debug("Request response stream completed for session: {}", sessionId); + }); + sseBuilder.onTimeout(() -> { + logger.debug("Request response stream timed out for session: {}", sessionId); + }); + + WebMvcStreamableMcpSessionTransport sessionTransport = new WebMvcStreamableMcpSessionTransport( + sessionId, sseBuilder); + + try { + session.responseStream(jsonrpcRequest, sessionTransport) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + } + catch (Exception e) { + logger.error("Failed to handle request stream: {}", e.getMessage()); + sseBuilder.error(e); + } + }, Duration.ZERO); + } + else { + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(new McpError("Unknown message type")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().body(new McpError("Invalid message format")); + } + catch (Exception e) { + logger.error("Error handling message: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); + } + } + + /** + * Handles DELETE requests for session deletion. + * @param request The incoming server request + * @return A ServerResponse indicating success or appropriate error status + */ + private ServerResponse handleDelete(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + if (this.disallowDelete) { + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { + return ServerResponse.badRequest().body("Session ID required in mcp-session-id header"); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + try { + session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); + this.sessions.remove(sessionId); + return ServerResponse.ok().build(); + } + catch (Exception e) { + logger.error("Failed to delete session {}: {}", sessionId, e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); + } + } + + /** + * Implementation of McpStreamableServerTransport for WebMVC SSE sessions. This class + * handles the transport-level communication for a specific client session. + * + *

+ * This class is thread-safe and uses a ReentrantLock to synchronize access to the + * underlying SSE builder to prevent race conditions when multiple threads attempt to + * send messages concurrently. + */ + private class WebMvcStreamableMcpSessionTransport implements McpStreamableServerTransport { + + private final String sessionId; + + private final SseBuilder sseBuilder; + + private final ReentrantLock lock = new ReentrantLock(); + + private volatile boolean closed = false; + + /** + * Creates a new session transport with the specified ID and SSE builder. + * @param sessionId The unique identifier for this session + * @param sseBuilder The SSE builder for sending server events to the client + */ + WebMvcStreamableMcpSessionTransport(String sessionId, SseBuilder sseBuilder) { + this.sessionId = sessionId; + this.sseBuilder = sseBuilder; + logger.debug("Streamable session transport {} initialized with SSE builder", sessionId); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection. + * @param message The JSON-RPC message to send + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return sendMessage(message, null); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection with a + * specific message ID. + * @param message The JSON-RPC message to send + * @param messageId The message ID for SSE event identification + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) { + return Mono.fromRunnable(() -> { + if (this.closed) { + logger.debug("Attempted to send message to closed session: {}", this.sessionId); + return; + } + + this.lock.lock(); + try { + if (this.closed) { + logger.debug("Session {} was closed during message send attempt", this.sessionId); + return; + } + + String jsonText = objectMapper.writeValueAsString(message); + this.sseBuilder.id(messageId != null ? messageId : this.sessionId) + .event(MESSAGE_EVENT_TYPE) + .data(jsonText); + logger.debug("Message sent to session {} with ID {}", this.sessionId, messageId); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage()); + try { + this.sseBuilder.error(e); + } + catch (Exception errorException) { + logger.error("Failed to send error to SSE builder for session {}: {}", this.sessionId, + errorException.getMessage()); + } + } + finally { + this.lock.unlock(); + } + }); + } + + /** + * Converts data from one type to another using the configured ObjectMapper. + * @param data The source data object to convert + * @param typeRef The target type reference + * @return The converted object of type T + * @param The target type + */ + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when the shutdown is complete + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + WebMvcStreamableMcpSessionTransport.this.close(); + }); + } + + /** + * Closes the transport immediately. + */ + @Override + public void close() { + this.lock.lock(); + try { + if (this.closed) { + logger.debug("Session transport {} already closed", this.sessionId); + return; + } + + this.closed = true; + + this.sseBuilder.complete(); + logger.debug("Successfully completed SSE builder for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); + } + finally { + this.lock.unlock(); + } + } + + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebMvcStreamableServerTransportProvider}. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private String mcpEndpoint = "/mcp"; + + private boolean disallowDelete = false; + + private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + + private Duration keepAliveInterval; + + /** + * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param objectMapper The ObjectMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if objectMapper is null + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param mcpEndpoint The MCP endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if mcpEndpoint is null + */ + public Builder mcpEndpoint(String mcpEndpoint) { + Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); + this.mcpEndpoint = mcpEndpoint; + return this; + } + + /** + * Sets whether to disallow DELETE requests on the endpoint. + * @param disallowDelete true to disallow DELETE requests, false otherwise + * @return this builder instance + */ + public Builder disallowDelete(boolean disallowDelete) { + this.disallowDelete = disallowDelete; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Sets the keep-alive interval for the transport. If set, a keep-alive scheduler + * will be created to periodically check and send keep-alive messages to clients. + * @param keepAliveInterval The interval duration for keep-alive messages, or null + * to disable keep-alive + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Builds a new instance of {@link WebMvcStreamableServerTransportProvider} with + * the configured settings. + * @return A new WebMvcStreamableServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public WebMvcStreamableServerTransportProvider build() { + Assert.notNull(this.objectMapper, "ObjectMapper must be set"); + Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); + + return new WebMvcStreamableServerTransportProvider(this.objectMapper, this.mcpEndpoint, this.disallowDelete, + this.contextExtractor, this.keepAliveInterval); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java new file mode 100644 index 000000000..66349216d --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java @@ -0,0 +1,121 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.Timeout; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import reactor.netty.DisposableServer; + +/** + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebMcpStreamableAsyncServerTransportTests extends AbstractMcpAsyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MCP_ENDPOINT = "/mcp"; + + private DisposableServer httpServer; + + private AnnotationConfigWebApplicationContext appContext; + + private Tomcat tomcat; + + private McpStreamableServerTransportProvider transportProvider; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcSseServerTransportProvider() { + return WebMvcStreamableServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .mcpEndpoint(MCP_ENDPOINT) + .build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + transportProvider = appContext.getBean(McpStreamableServerTransportProvider.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + tomcat.start(); + tomcat.getConnector(); // Create and start the connector + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return transportProvider; + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java new file mode 100644 index 000000000..cab487f12 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java @@ -0,0 +1,121 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.Timeout; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import reactor.netty.DisposableServer; + +/** + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebMcpStreamableSyncServerTransportTests extends AbstractMcpSyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MCP_ENDPOINT = "/mcp"; + + private DisposableServer httpServer; + + private AnnotationConfigWebApplicationContext appContext; + + private Tomcat tomcat; + + private McpStreamableServerTransportProvider transportProvider; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcSseServerTransportProvider() { + return WebMvcStreamableServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .mcpEndpoint(MCP_ENDPOINT) + .build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + transportProvider = appContext.getBean(McpStreamableServerTransportProvider.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + tomcat.start(); + tomcat.getConnector(); // Create and start the connector + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return transportProvider; + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java index 6a6ad17e9..bb4c2bf37 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java @@ -49,8 +49,7 @@ public RouterFunction routerFunction(WebMvcSseServerTransportPro private AnnotationConfigWebApplicationContext appContext; - @Override - protected McpServerTransportProvider createMcpTransportProvider() { + private McpServerTransportProvider createMcpTransportProvider() { // Set up Tomcat first tomcat = new Tomcat(); tomcat.setPort(PORT); @@ -90,6 +89,11 @@ protected McpServerTransportProvider createMcpTransportProvider() { return transportProvider; } + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + @Override protected void onStart() { } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java index 1b5218cc5..cce36d191 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java @@ -91,8 +91,15 @@ static class TestConfig { @Bean public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { - return new WebMvcSseServerTransportProvider(new ObjectMapper(), CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT, - WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT); + return WebMvcSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .baseUrl(CUSTOM_CONTEXT_PATH) + .messageEndpoint(MESSAGE_ENDPOINT) + .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + .build(); + // return new WebMvcSseServerTransportProvider(new ObjectMapper(), + // CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT, + // WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT); } @Bean diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 9f2d6abff..995cbd165 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -4,67 +4,32 @@ package io.modelcontextprotocol.server; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.ServerResponse; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification; import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; -import reactor.test.StepVerifier; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.web.client.RestClient; -import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.ServerResponse; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; - -import net.javacrumbs.jsonunit.core.Option; - -class WebMvcSseIntegrationTests { +class WebMvcSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); @@ -72,7 +37,17 @@ class WebMvcSseIntegrationTests { private WebMvcSseServerTransportProvider mcpServerTransportProvider; - McpClient.SyncSpec clientBuilder; + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders.put("httpclient", + McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + port).build()) + .requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", McpClient + .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + port)).build()) + .requestTimeout(Duration.ofHours(10))); + } @Configuration @EnableWebMvc @@ -80,7 +55,10 @@ static class TestConfig { @Bean public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { - return new WebMvcSseServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + return WebMvcSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .build(); } @Bean @@ -105,7 +83,7 @@ public void before() { throw new RuntimeException("Failed to start Tomcat", e); } - clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).build()); + prepareClients(PORT, MESSAGE_ENDPOINT); // Get the transport from Spring context mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); @@ -133,1102 +111,14 @@ public void after() { } } - // --------------------------------------- - // Sampling Tests - // --------------------------------------- - @Test - void testCreateMessageWithoutSamplingCapabilities() { - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - return Mono.just(mock(CallToolResult.class)); - }) - .build(); - - //@formatter:off - var server = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try ( - // Create client without sampling capabilities - var client = clientBuilder - .clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) - .build()) {//@formatter:on - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - } - } - server.close(); - } - - @Test - void testCreateMessageSuccess() { - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var createMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - //@formatter:off - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try ( - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) {//@formatter:on - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull().isEqualTo(callResponse); - } - mcpServer.close(); - } - - @Test - void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { - - // Client - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(4)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { - - // Client - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("Timeout"); - - mcpClient.close(); - mcpServer.close(); - } - - // --------------------------------------- - // Elicitation Tests - // --------------------------------------- - @Test - void testCreateElicitationWithoutElicitationCapabilities() { - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - exchange.createElicitation(mock(McpSchema.ElicitRequest.class)).block(); - - return Mono.just(mock(CallToolResult.class)); - }) - .build(); - - var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - - try ( - // Create client without elicitation capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with elicitation capabilities"); - } - } - server.closeGracefully().block(); - } - - @Test - void testCreateElicitationSuccess() { - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - - return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, - Map.of("message", request.message())); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = McpSchema.ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - mcpServer.closeGracefully().block(); - } - - @Test - void testCreateElicitationWithRequestTimeoutSuccess() { - - // Client - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, - Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = McpSchema.ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(3)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - @Test - void testCreateElicitationWithRequestTimeoutFail() { - - // Client - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, - Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = McpSchema.ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("Timeout"); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - @Test - void testRootsSuccess() { - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); - - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); - } - - mcpServer.close(); - } - - @Test - void testRootsWithoutCapability() { - - McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - exchange.listRoots(); // try to list roots - - return mock(CallToolResult.class); - }) - .build(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { - }).tools(tool).build(); - - try ( - // Create client without roots capability - // No roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); - } - } - - mcpServer.close(); - } - - @Test - void testRootsNotificationWithEmptyRootsList() { - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(List.of()) // Empty roots list - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - } - - mcpServer.close(); - } - - @Test - void testRootsWithMultipleHandlers() { - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef1 = new AtomicReference<>(); - AtomicReference> rootsRef2 = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - @Test - void testRootsServerCloseWithActiveSubscription() { - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testToolCallSuccess() { - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }) - .build(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull().isEqualTo(callResponse); - } - - mcpServer.close(); - } - - @Test - void testThrowingToolCallIsCaughtBeforeTimeout() { - McpSyncServer mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - // We trigger a timeout on blocking read, raising an exception - Mono.never().block(Duration.ofSeconds(1)); - return null; - })) - .build(); - - try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // We expect the tool call to fail immediately with the exception raised by - // the offending tool - // instead of getting back a timeout. - assertThatExceptionOfType(McpError.class) - .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) - .withMessageContaining("Timeout on blocking read"); - } - - mcpServer.close(); - } - - @Test - void testToolListChangeHandlingSuccess() { - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }) - .build(); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - rootsRef.set(toolsUpdate); - }).build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); - - // Remove a tool - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema)) - .callHandler((exchange, request) -> callResponse) - .build(); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); - } - - mcpServer.close(); - } - - @Test - void testInitialize() { - - var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - } - - mcpServer.close(); - } - - @Test - void testPingSuccess() { - // Create server with a tool that uses ping functionality - AtomicReference executionOrder = new AtomicReference<>(""); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("ping-async-test", "Test ping async behavior", emptyJsonSchema), - (exchange, request) -> { - - executionOrder.set(executionOrder.get() + "1"); - - // Test async ping behavior - return exchange.ping().doOnNext(result -> { - - assertThat(result).isNotNull(); - // Ping should return an empty object or map - assertThat(result).isInstanceOf(Map.class); - - executionOrder.set(executionOrder.get() + "2"); - assertThat(result).isNotNull(); - }).then(Mono.fromCallable(() -> { - executionOrder.set(executionOrder.get() + "3"); - return new CallToolResult("Async ping test completed", false); - })); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call the tool that tests ping async behavior - CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); - - // Verify execution order - assertThat(executionOrder.get()).isEqualTo("123"); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Tool Structured Output Schema Tests - // --------------------------------------- - - @Test - void testStructuredOutputValidationSuccess() { - // Create a tool with output schema - Map outputSchema = Map.of( - "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", - Map.of("type", "string"), "timestamp", Map.of("type", "string")), - "required", List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - String expression = (String) request.getOrDefault("expression", "2 + 3"); - double result = evaluateExpression(expression); - return CallToolResult.builder() - .structuredContent( - Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) - .build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Verify tool is listed with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call tool with valid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - - // In WebMVC, structured content is returned properly - if (response.structuredContent() != null) { - assertThat(response.structuredContent()).containsEntry("result", 5.0) - .containsEntry("operation", "2 + 3") - .containsEntry("timestamp", "2024-01-01T10:00:00Z"); - } - else { - // Fallback to checking content if structured content is not available - assertThat(response.content()).isNotEmpty(); - } - - assertThat(response.structuredContent()).isNotNull(); - assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); - } - - mcpServer.close(); - } - - @Test - void testStructuredOutputValidationFailure() { - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", - List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - // Return invalid structured output. Result should be number, missing - // operation - return CallToolResult.builder() - .addTextContent("Invalid calculation") - .structuredContent(Map.of("result", "not-a-number", "extra", "field")) - .build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool with invalid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).contains("Validation failed"); - } - - mcpServer.close(); - } - - @Test - void testStructuredOutputMissingStructuredContent() { - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number")), "required", List.of("result")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - // Return result without structured content but tool has output schema - return CallToolResult.builder().addTextContent("Calculation completed").build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool that should return structured content but doesn't - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).isEqualTo( - "Response missing structured content which is expected when calling tool with non-empty outputSchema"); - } - - mcpServer.close(); - } - - @Test - void testStructuredOutputRuntimeToolAddition() { - // Start server without tools - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Initially no tools - assertThat(mcpClient.listTools().tools()).isEmpty(); - - // Add tool with output schema at runtime - Map outputSchema = Map.of("type", "object", "properties", - Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", - List.of("message", "count")); - - Tool dynamicTool = Tool.builder() - .name("dynamic-tool") - .description("Dynamically added tool") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification toolSpec = new McpServerFeatures.SyncToolSpecification(dynamicTool, - (exchange, request) -> { - int count = (Integer) request.getOrDefault("count", 1); - return CallToolResult.builder() - .addTextContent("Dynamic tool executed " + count + " times") - .structuredContent(Map.of("message", "Dynamic execution", "count", count)) - .build(); - }); - - // Add tool to server - mcpServer.addTool(toolSpec); - - // Wait for tool list change notification - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(mcpClient.listTools().tools()).hasSize(1); - }); - - // Verify tool was added with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call dynamically added tool - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) response.content().get(0)).text()) - .isEqualTo("Dynamic tool executed 3 times"); - - assertThat(response.structuredContent()).isNotNull(); - assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"count":3,"message":"Dynamic execution"}""")); - } - - mcpServer.close(); + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(mcpServerTransportProvider); } - private double evaluateExpression(String expression) { - // Simple expression evaluator for testing - return switch (expression) { - case "2 + 3" -> 5.0; - case "10 * 2" -> 20.0; - case "7 + 8" -> 15.0; - case "5 + 3" -> 8.0; - default -> 0.0; - }; + @Override + protected SingleSessionSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(mcpServerTransportProvider); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java index 1964703c1..7e49ddf3b 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java @@ -49,7 +49,11 @@ public RouterFunction routerFunction(WebMvcSseServerTransportPro private AnnotationConfigWebApplicationContext appContext; @Override - protected WebMvcSseServerTransportProvider createMcpTransportProvider() { + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + + private WebMvcSseServerTransportProvider createMcpTransportProvider() { // Set up Tomcat first tomcat = new Tomcat(); tomcat.setPort(PORT); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java new file mode 100644 index 000000000..802363d59 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java @@ -0,0 +1,130 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.AbstractStatelessIntegrationTests; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; +import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; +import io.modelcontextprotocol.server.transport.WebMvcStatelessServerTransport; +import reactor.core.scheduler.Schedulers; + +class WebMvcStatelessIntegrationTests extends AbstractStatelessIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcStatelessServerTransport mcpServerTransport; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcStatelessServerTransport webMvcStatelessServerTransport() { + + return WebMvcStatelessServerTransport.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .build(); + + } + + @Bean + public RouterFunction routerFunction(WebMvcStatelessServerTransport statelessServerTransport) { + return statelessServerTransport.getRouterFunction(); + } + + } + + private TomcatTestUtil.TomcatServer tomcatServer; + + @Override + protected StatelessAsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransport); + } + + @Override + protected StatelessSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransport); + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders.put("httpclient", McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + port).endpoint(mcpEndpoint).build()) + .requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", + McpClient + .sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + port)) + .endpoint(mcpEndpoint) + .build()) + .requestTimeout(Duration.ofHours(10))); + } + + @BeforeEach + public void before() { + + tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, TestConfig.class); + + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + prepareClients(PORT, MESSAGE_ENDPOINT); + + // Get the transport from Spring context + this.mcpServerTransport = tomcatServer.appContext().getBean(WebMvcStatelessServerTransport.class); + + } + + @AfterEach + public void after() { + reactor.netty.http.HttpResources.disposeLoopsAndConnections(); + if (this.mcpServerTransport != null) { + this.mcpServerTransport.closeGracefully().block(); + } + Schedulers.shutdownNow(); + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java new file mode 100644 index 000000000..800065915 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java @@ -0,0 +1,140 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import reactor.core.scheduler.Schedulers; + +class WebMvcStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcStreamableServerTransportProvider mcpServerTransportProvider; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcStreamableServerTransportProvider() { + return WebMvcStreamableServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .mcpEndpoint(MESSAGE_ENDPOINT) + .build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + private TomcatTestUtil.TomcatServer tomcatServer; + + @BeforeEach + public void before() { + + tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, TestConfig.class); + + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(MESSAGE_ENDPOINT) + .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", + McpClient.sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(MESSAGE_ENDPOINT) + .build())); + + // Get the transport from Spring context + this.mcpServerTransportProvider = tomcatServer.appContext() + .getBean(WebMvcStreamableServerTransportProvider.class); + + } + + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransportProvider); + } + + @Override + protected SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransportProvider); + } + + @AfterEach + public void after() { + reactor.netty.http.HttpResources.disposeLoopsAndConnections(); + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + Schedulers.shutdownNow(); + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders.put("httpclient", McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + port).endpoint(mcpEndpoint).build()) + .requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", + McpClient + .sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + port)) + .endpoint(mcpEndpoint) + .build()) + .requestTimeout(Duration.ofHours(10))); + } + +} diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index f24d9fab2..563f60de9 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.12.0-SNAPSHOT mcp-test jar @@ -24,7 +24,7 @@ io.modelcontextprotocol.sdk mcp - 0.11.0-SNAPSHOT + 0.12.0-SNAPSHOT @@ -91,6 +91,13 @@ ${logback.version} + + net.javacrumbs.json-unit + json-unit-assertj + ${json-unit-assertj.version} + + + diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java new file mode 100644 index 000000000..26fd71d2b --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java @@ -0,0 +1,1538 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol; + +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertWith; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import net.javacrumbs.jsonunit.core.Option; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +public abstract class AbstractMcpClientServerIntegrationTests { + + protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + abstract protected void prepareClients(int port, String mcpEndpoint); + + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); + + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void simple(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1000)) + .build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .requestTimeout(Duration.ofSeconds(1000)) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + } + server.closeGracefully(); + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithoutSamplingCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); + return Mono.just(mock(CallToolResult.class)); + }) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } + } + server.closeGracefully(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + AtomicReference samplingResult = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { + + // Client + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + AtomicReference samplingResult = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(4)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest).thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("1000ms"); + + mcpClient.close(); + mcpServer.close(); + } + + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithoutElicitationCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> exchange.createElicitation(mock(ElicitRequest.class)) + .then(Mono.just(mock(CallToolResult.class)))) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + // Create client without elicitation capabilities + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + AtomicReference resultRef = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + return exchange.createElicitation(elicitationRequest) + .doOnNext(resultRef::set) + .then(Mono.just(callResponse)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + assertWith(resultRef.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutFail(String clientType) { + + var latch = new CountDownLatch(1); + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + try { + if (!latch.await(2, TimeUnit.SECONDS)) { + throw new RuntimeException("Timeout waiting for elicitation processing"); + } + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + AtomicReference resultRef = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + return exchange.createElicitation(elicitationRequest) + .doOnNext(resultRef::set) + .then(Mono.just(callResponse)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) // 1 second. + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + ElicitResult elicitResult = resultRef.get(); + assertThat(elicitResult).isNull(); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithoutCapability(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); + + try ( + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsNotificationWithEmptyRootsList(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithMultipleHandlers(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsServerCloseWithActiveSubscription(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + + return callResponse; + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull().isEqualTo(callResponse); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpSyncServer mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool1") + .description("tool1 description") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + // We trigger a timeout on blocking read, raising an exception + Mono.never().block(Duration.ofSeconds(1)); + return null; + }) + .build()) + .build(); + + try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // We expect the tool call to fail immediately with the exception raised by + // the offending tool + // instead of getting back a timeout. + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) + .withMessageContaining("Timeout on blocking read"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + return callResponse; + }) + .build(); + + AtomicReference> toolsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + toolsRef.set(toolsUpdate); + } + catch (Exception e) { + e.printStackTrace(); + } + }).build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(toolsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool2") + .description("tool2 description") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> callResponse) + .build(); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).containsAll(List.of(tool2.tool())); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = prepareSyncServerBuilder().build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testLoggingNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 3; + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + ; + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("logging-test") + .description("Test logging notifications") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications with different levels + + //@formatter:off + return exchange // This should be filtered out (DEBUG < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message") + .build()) + .then(exchange // This should be sent (NOTICE >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.NOTICE) + .logger("test-logger") + .data("Notice message") + .build())) + .then(exchange // This should be sent (ERROR > NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Error message") + .build())) + .then(exchange // This should be filtered out (INFO < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Another info message") + .build())) + .then(exchange // This should be sent (ERROR >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Another error message") + .build())) + .thenReturn(new CallToolResult("Logging test completed", false)); + //@formatter:on + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with logging notification handler + var mcpClient = clientBuilder.loggingConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Set minimum logging level to NOTICE + mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); + + // Call the tool that sends logging notifications + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); + + // First notification should be NOTICE level + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); + + // Second notification should be ERROR level + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); + + // Third notification should be ERROR level + assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); + } + mcpServer.close(); + } + + // --------------------------------------- + // Progress Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testProgressNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress + // token + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(McpSchema.Tool.builder() + .name("progress-test") + .description("Test progress notifications") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications + var progressToken = (String) request.meta().get("progressToken"); + + return exchange + .progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.0, 1.0, "Processing started")) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Processing data"))) + .then(// Send a progress notification with another progress value + // should + exchange.progressNotification(new McpSchema.ProgressNotification("another-progress-token", + 0.0, 1.0, "Another processing started"))) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) + .thenReturn(new CallToolResult(("Progress test completed"), false)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with progress notification handler + var mcpClient = clientBuilder.progressConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that sends progress notifications + McpSchema.CallToolRequest callToolRequest = McpSchema.CallToolRequest.builder() + .name("progress-test") + .meta(Map.of("progressToken", "test-progress-token")) + .build(); + CallToolResult result = mcpClient.callTool(callToolRequest); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.message(), n -> n)); + + // First notification should be 0.0/1.0 progress + assertThat(notificationMap.get("Processing started").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing started").message()).isEqualTo("Processing started"); + + // Second notification should be 0.5/1.0 progress + assertThat(notificationMap.get("Processing data").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing data").progress()).isEqualTo(0.5); + assertThat(notificationMap.get("Processing data").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing data").message()).isEqualTo("Processing data"); + + // Third notification should be another progress token with 0.0/1.0 progress + assertThat(notificationMap.get("Another processing started").progressToken()) + .isEqualTo("another-progress-token"); + assertThat(notificationMap.get("Another processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Another processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Another processing started").message()) + .isEqualTo("Another processing started"); + + // Fourth notification should be 1.0/1.0 progress + assertThat(notificationMap.get("Processing completed").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing completed").progress()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); + } + finally { + mcpServer.close(); + } + } + + // --------------------------------------- + // Completion Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : Completion call") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCompletionShouldReturnExpectedSuggestions(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + var expectedValues = List.of("python", "pytorch", "pyside"); + var completionResponse = new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + + AtomicReference samplingRequest = new AtomicReference<>(); + BiFunction completionHandler = (mcpSyncServerExchange, + request) -> { + samplingRequest.set(request); + return completionResponse; + }; + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpServerFeatures.SyncPromptSpecification( + new Prompt("code_review", "Code review", "this is code review prompt", + List.of(new PromptArgument("language", "Language", "string", false))), + (mcpSyncServerExchange, getPromptRequest) -> null)) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new McpSchema.PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CompleteRequest request = new CompleteRequest( + new PromptReference("ref/prompt", "code_review", "Code review"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult result = mcpClient.completeCompletion(request); + + assertThat(result).isNotNull(); + + assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); + assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); + assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Ping Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testPingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that uses ping functionality + AtomicReference executionOrder = new AtomicReference<>(""); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("ping-async-test") + .description("Test ping async behavior") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + + executionOrder.set(executionOrder.get() + "1"); + + // Test async ping behavior + return exchange.ping().doOnNext(result -> { + + assertThat(result).isNotNull(); + // Ping should return an empty object or map + assertThat(result).isInstanceOf(Map.class); + + executionOrder.set(executionOrder.get() + "2"); + assertThat(result).isNotNull(); + }).then(Mono.fromCallable(() -> { + executionOrder.set(executionOrder.get() + "3"); + return new CallToolResult("Async ping test completed", false); + })); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that tests ping async behavior + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); + + // Verify execution order + assertThat(executionOrder.get()).isEqualTo("123"); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + // In WebMVC, structured content is returned properly + if (response.structuredContent() != null) { + assertThat(response.structuredContent()).containsEntry("result", 5.0) + .containsEntry("operation", "2 + 3") + .containsEntry("timestamp", "2024-01-01T10:00:00Z"); + } + else { + // Fallback to checking content if structured content is not available + assertThat(response.content()).isNotEmpty(); + } + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputValidationFailure(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputMissingStructuredContent(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputRuntimeToolAddition(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification toolSpec = McpServerFeatures.SyncToolSpecification.builder() + .tool(dynamicTool) + .callHandler((exchange, request) -> { + int count = (Integer) request.arguments().getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }) + .build(); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + + mcpServer.close(); + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java new file mode 100644 index 000000000..618247d61 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java @@ -0,0 +1,539 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol; + +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.awaitility.Awaitility.await; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; +import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpStatelessSyncServer; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import net.javacrumbs.jsonunit.core.Option; +import reactor.core.publisher.Mono; + +public abstract class AbstractStatelessIntegrationTests { + + protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + abstract protected void prepareClients(int port, String mcpEndpoint); + + abstract protected StatelessAsyncSpecification prepareAsyncServerBuilder(); + + abstract protected StatelessSyncSpecification prepareSyncServerBuilder(); + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void simple(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1000)) + .build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .requestTimeout(Duration.ofSeconds(1000)) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + } + server.closeGracefully(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpStatelessServerFeatures.SyncToolSpecification tool1 = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((ctx, request) -> { + + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + + return callResponse; + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull().isEqualTo(callResponse); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpStatelessSyncServer mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool1") + .description("tool1 description") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((context, request) -> { + // We trigger a timeout on blocking read, raising an exception + Mono.never().block(Duration.ofSeconds(1)); + return null; + }) + .build()) + .build(); + + try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // We expect the tool call to fail immediately with the exception raised by + // the offending tool + // instead of getting back a timeout. + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) + .withMessageContaining("Timeout on blocking read"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpStatelessServerFeatures.SyncToolSpecification tool1 = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((ctx, request) -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + return callResponse; + }) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + + rootsRef.set(toolsUpdate); + }).build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + // Remove a tool + mcpServer.removeTool("tool1"); + + // Add a new tool + McpStatelessServerFeatures.SyncToolSpecification tool2 = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(Tool.builder() + .name("tool2") + .description("tool2 description") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> callResponse) + .build(); + + mcpServer.addTool(tool2); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = prepareSyncServerBuilder().build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + // In WebMVC, structured content is returned properly + if (response.structuredContent() != null) { + assertThat(response.structuredContent()).containsEntry("result", 5.0) + .containsEntry("operation", "2 + 3") + .containsEntry("timestamp", "2024-01-01T10:00:00Z"); + } + else { + // Fallback to checking content if structured content is not available + assertThat(response.content()).isNotEmpty(); + } + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputValidationFailure(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputMissingStructuredContent(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + var tool = McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputRuntimeToolAddition(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + var toolSpec = McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(dynamicTool) + .callHandler((exchange, request) -> { + int count = (Integer) request.arguments().getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }) + .build(); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + + mcpServer.close(); + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java index 22e8f195b..ed34ebff6 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -1,6 +1,7 @@ /* * Copyright 2024-2024 the original author or authors. */ + package io.modelcontextprotocol.client; import eu.rekawek.toxiproxy.Proxy; diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index eb08bdcde..1e87d4420 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -21,6 +21,8 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -42,7 +44,7 @@ public abstract class AbstractMcpAsyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected McpServerTransportProvider createMcpTransportProvider(); + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); protected void onStart() { } @@ -63,28 +65,29 @@ void tearDown() { // Server Lifecycle Tests // --------------------------------------- - @Test - void testConstructorWithInvalidArguments() { + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "sse", "streamable" }) + void testConstructorWithInvalidArguments(String serverType) { assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport provider must not be null"); - assertThatThrownBy( - () -> McpServer.async(createMcpTransportProvider()).serverInfo((McpSchema.Implementation) null)) + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo((McpSchema.Implementation) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + McpServer.AsyncSpecification builder = prepareAsyncServerBuilder(); + var mcpAsyncServer = builder.serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); } @Test void testImmediateClose() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); } @@ -104,8 +107,7 @@ void testImmediateClose() { @Deprecated void testAddTool() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -119,8 +121,7 @@ void testAddTool() { @Test void testAddToolCall() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -137,8 +138,7 @@ void testAddToolCall() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -158,8 +158,7 @@ void testAddDuplicateTool() { void testAddDuplicateToolCall() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -180,8 +179,7 @@ void testDuplicateToolCallDuringBuilding() { Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) // Duplicate! @@ -203,8 +201,7 @@ void testDuplicateToolsInBatchListRegistration() { .build() // Duplicate! ); - assertThatThrownBy(() -> McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(specs) .build()).isInstanceOf(IllegalArgumentException.class) @@ -215,8 +212,7 @@ void testDuplicateToolsInBatchListRegistration() { void testDuplicateToolsInBatchVarargsRegistration() { Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) @@ -235,8 +231,7 @@ void testDuplicateToolsInBatchVarargsRegistration() { void testRemoveTool() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -248,8 +243,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -264,8 +258,7 @@ void testRemoveNonexistentTool() { void testNotifyToolsListChanged() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -281,7 +274,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); @@ -290,7 +283,7 @@ void testNotifyResourcesListChanged() { @Test void testNotifyResourcesUpdated() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier .create(mcpAsyncServer @@ -302,8 +295,7 @@ void testNotifyResourcesUpdated() { @Test void testAddResource() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -319,8 +311,7 @@ void testAddResource() { @Test void testAddResourceWithNullSpecification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -335,9 +326,7 @@ void testAddResourceWithNullSpecification() { @Test void testAddResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); @@ -353,9 +342,7 @@ void testAddResourceWithoutCapability() { @Test void testRemoveResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) @@ -369,7 +356,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); @@ -378,8 +365,7 @@ void testNotifyPromptsListChanged() { @Test void testAddPromptWithNullSpecification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); @@ -392,9 +378,7 @@ void testAddPromptWithNullSpecification() { @Test void testAddPromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( @@ -410,9 +394,7 @@ void testAddPromptWithoutCapability() { @Test void testRemovePromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) @@ -429,8 +411,7 @@ void testRemovePrompt() { prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .prompts(specification) .build(); @@ -442,8 +423,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpAsyncServer2 = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer2 = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -466,8 +446,7 @@ void testRootsChangeHandlers() { var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var singleConsumerServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumerCalled[0] = true; if (!roots.isEmpty()) { @@ -486,8 +465,7 @@ void testRootsChangeHandlers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var multipleConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumer1Called[0] = true; rootsContent[0] = roots; @@ -500,8 +478,7 @@ void testRootsChangeHandlers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var errorHandlingServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) @@ -513,9 +490,7 @@ void testRootsChangeHandlers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var noConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 4d5f9f772..5d70ae4c0 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -28,7 +28,7 @@ /** * Test suite for the {@link McpSyncServer} that can be used with different - * {@link McpTransportProvider} implementations. + * {@link McpServerTransportProvider} implementations. * * @author Christian Tzolov */ @@ -40,7 +40,7 @@ public abstract class AbstractMcpSyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected McpServerTransportProvider createMcpTransportProvider(); + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); protected void onStart() { } @@ -68,28 +68,28 @@ void testConstructorWithInvalidArguments() { .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport provider must not be null"); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()).serverInfo(null)) + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo(null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); } @Test void testImmediateClose() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); } @Test void testGetAsyncServer() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); @@ -111,8 +111,7 @@ void testGetAsyncServer() { @Test @Deprecated void testAddTool() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -126,8 +125,7 @@ void testAddTool() { @Test void testAddToolCall() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -145,8 +143,7 @@ void testAddToolCall() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) .build(); @@ -163,8 +160,7 @@ void testAddDuplicateTool() { void testAddDuplicateToolCall() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) .build(); @@ -183,8 +179,7 @@ void testDuplicateToolCallDuringBuilding() { Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) // Duplicate! @@ -206,8 +201,7 @@ void testDuplicateToolsInBatchListRegistration() { .build() // Duplicate! ); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(specs) .build()).isInstanceOf(IllegalArgumentException.class) @@ -218,8 +212,7 @@ void testDuplicateToolsInBatchListRegistration() { void testDuplicateToolsInBatchVarargsRegistration() { Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) @@ -238,8 +231,7 @@ void testDuplicateToolsInBatchVarargsRegistration() { void testRemoveTool() { Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(tool, (exchange, args) -> new CallToolResult(List.of(), false)) .build(); @@ -251,8 +243,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -264,7 +255,7 @@ void testRemoveNonexistentTool() { @Test void testNotifyToolsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); @@ -277,7 +268,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); @@ -286,7 +277,7 @@ void testNotifyResourcesListChanged() { @Test void testNotifyResourcesUpdated() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) @@ -297,8 +288,7 @@ void testNotifyResourcesUpdated() { @Test void testAddResource() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -314,8 +304,7 @@ void testAddResource() { @Test void testAddResourceWithNullSpecification() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -328,9 +317,7 @@ void testAddResourceWithNullSpecification() { @Test void testAddResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); @@ -343,9 +330,7 @@ void testAddResourceWithoutCapability() { @Test void testRemoveResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with resource capabilities"); @@ -357,7 +342,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); @@ -366,8 +351,7 @@ void testNotifyPromptsListChanged() { @Test void testAddPromptWithNullSpecification() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); @@ -378,9 +362,7 @@ void testAddPromptWithNullSpecification() { @Test void testAddPromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, @@ -393,9 +375,7 @@ void testAddPromptWithoutCapability() { @Test void testRemovePromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); @@ -408,8 +388,7 @@ void testRemovePrompt() { (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .prompts(specification) .build(); @@ -421,8 +400,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -442,8 +420,7 @@ void testRootsChangeHandlers() { var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var singleConsumerServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { consumerCalled[0] = true; if (!roots.isEmpty()) { @@ -461,8 +438,7 @@ void testRootsChangeHandlers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var multipleConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { consumer1Called[0] = true; rootsContent[0] = roots; @@ -474,8 +450,7 @@ void testRootsChangeHandlers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var errorHandlingServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) @@ -486,7 +461,7 @@ void testRootsChangeHandlers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var noConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java index 0085f31ed..dbbf1a537 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/TestUtil.java @@ -1,6 +1,7 @@ /* * Copyright 2025 - 2025 the original author or authors. */ + package io.modelcontextprotocol.server; import java.io.IOException; diff --git a/mcp/pom.xml b/mcp/pom.xml index de4ee988a..1cf61c48f 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.12.0-SNAPSHOT mcp jar diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java b/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java index e33fafa6a..2e0b51748 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/LifecycleInitializer.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.client; import java.time.Duration; @@ -286,7 +290,8 @@ public Mono withIntitialization(String actionName, Function { logger.warn("Failed to initialize", ex); - return Mono.error(new McpError("Client failed to initialize " + actionName)); + return Mono.error( + new McpError("Client failed to initialize " + actionName + " due to: " + ex.getMessage())); }) .flatMap(operation); }); diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 9e861deba..0f2ee19fa 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -1,9 +1,12 @@ /* * Copyright 2024-2024 the original author or authors. */ + package io.modelcontextprotocol.client; import java.time.Duration; +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -174,7 +177,10 @@ public class McpAsyncClient { Map> requestHandlers = new HashMap<>(); // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, params -> Mono.just(Map.of())); + requestHandlers.put(McpSchema.METHOD_PING, params -> { + logger.debug("Received ping: {}", LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)); + return Mono.just(Map.of()); + }); // Roots List Request Handler if (this.clientCapabilities.roots() != null) { @@ -266,13 +272,20 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROGRESS, asyncProgressNotificationHandler(progressConsumersFinal)); - this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo, - List.of(McpSchema.LATEST_PROTOCOL_VERSION), initializationTimeout, - ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers, - con -> con.contextWrite(ctx))); + this.initializer = new LifecycleInitializer(clientCapabilities, clientInfo, transport.protocolVersions(), + initializationTimeout, ctx -> new McpClientSession(requestTimeout, transport, requestHandlers, + notificationHandlers, con -> con.contextWrite(ctx))); this.transport.setExceptionHandler(this.initializer::handleException); } + /** + * Get the current initialization result. + * @return the initialization result. + */ + public McpSchema.InitializeResult getCurrentInitializationResult() { + return this.initializer.currentInitializationResult(); + } + /** * Get the server capabilities that define the supported features and functionality. * @return The server capabilities diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index 83c4900d1..33784adcd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -73,6 +73,14 @@ public class McpSyncClient implements AutoCloseable { this.delegate = delegate; } + /** + * Get the current initialization result. + * @return the initialization result. + */ + public McpSchema.InitializeResult getCurrentInitializationResult() { + return this.delegate.getCurrentInitializationResult(); + } + /** * Get the server capabilities that define the supported features and functionality. * @return The server capabilities diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/AsyncHttpRequestCustomizer.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/AsyncHttpRequestCustomizer.java new file mode 100644 index 000000000..dee026d96 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/AsyncHttpRequestCustomizer.java @@ -0,0 +1,52 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.net.URI; +import java.net.http.HttpRequest; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.util.annotation.Nullable; + +/** + * Customize {@link HttpRequest.Builder} before executing the request, in either SSE or + * Streamable HTTP transport. + *

+ * When used in a non-blocking context, implementations MUST be non-blocking. + * + * @author Daniel Garnier-Moiroux + */ +public interface AsyncHttpRequestCustomizer { + + Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, + @Nullable String body); + + AsyncHttpRequestCustomizer NOOP = new Noop(); + + /** + * Wrap a sync implementation in an async wrapper. + *

+ * Do NOT wrap a blocking implementation for use in a non-blocking context. For a + * blocking implementation, consider using {@link Schedulers#boundedElastic()}. + */ + static AsyncHttpRequestCustomizer fromSync(SyncHttpRequestCustomizer customizer) { + return (builder, method, uri, body) -> Mono.fromSupplier(() -> { + customizer.customize(builder, method, uri, body); + return builder; + }); + } + + class Noop implements AsyncHttpRequestCustomizer { + + @Override + public Publisher customize(HttpRequest.Builder builder, String method, URI endpoint, + String body) { + return Mono.just(builder); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 8598e3164..473f71fbb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -1,6 +1,7 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2024 - 2025 the original author or authors. */ + package io.modelcontextprotocol.client.transport; import java.io.IOException; @@ -9,6 +10,7 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.time.Duration; +import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -24,6 +26,7 @@ import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; @@ -61,6 +64,10 @@ */ public class HttpClientSseClientTransport implements McpClientTransport { + private static final String MCP_PROTOCOL_VERSION = ProtocolVersions.MCP_2024_11_05; + + private static final String MCP_PROTOCOL_VERSION_HEADER_NAME = "MCP-Protocol-Version"; + private static final Logger logger = LoggerFactory.getLogger(HttpClientSseClientTransport.class); /** SSE event type for JSON-RPC messages */ @@ -102,6 +109,11 @@ public class HttpClientSseClientTransport implements McpClientTransport { */ protected final Sinks.One messageEndpointSink = Sinks.one(); + /** + * Customizer to modify requests before they are executed. + */ + private final AsyncHttpRequestCustomizer httpRequestCustomizer; + /** * Creates a new transport instance with default HTTP client and object mapper. * @param baseUri the base URI of the MCP server @@ -172,18 +184,43 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques * @param objectMapper the object mapper for JSON serialization/deserialization * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null */ + @Deprecated(forRemoval = true) HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, String sseEndpoint, ObjectMapper objectMapper) { + this(httpClient, requestBuilder, baseUri, sseEndpoint, objectMapper, AsyncHttpRequestCustomizer.NOOP); + } + + /** + * Creates a new transport instance with custom HTTP client builder, object mapper, + * and headers. + * @param httpClient the HTTP client to use + * @param requestBuilder the HTTP request builder to use + * @param baseUri the base URI of the MCP server + * @param sseEndpoint the SSE endpoint path + * @param objectMapper the object mapper for JSON serialization/deserialization + * @param httpRequestCustomizer customizer for the requestBuilder before executing + * requests + * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null + */ + HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, + String sseEndpoint, ObjectMapper objectMapper, AsyncHttpRequestCustomizer httpRequestCustomizer) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.hasText(baseUri, "baseUri must not be empty"); Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); Assert.notNull(httpClient, "httpClient must not be null"); Assert.notNull(requestBuilder, "requestBuilder must not be null"); + Assert.notNull(httpRequestCustomizer, "httpRequestCustomizer must not be null"); this.baseUri = URI.create(baseUri); this.sseEndpoint = sseEndpoint; this.objectMapper = objectMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; + this.httpRequestCustomizer = httpRequestCustomizer; + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); } /** @@ -213,6 +250,8 @@ public static class Builder { private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() .header("Content-Type", "application/json"); + private AsyncHttpRequestCustomizer httpRequestCustomizer = AsyncHttpRequestCustomizer.NOOP; + /** * Creates a new builder instance. */ @@ -310,31 +349,67 @@ public Builder objectMapper(ObjectMapper objectMapper) { return this; } + /** + * Sets the customizer for {@link HttpRequest.Builder}, to modify requests before + * executing them. + *

+ * This overrides the customizer from + * {@link #asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer)}. + *

+ * Do NOT use a blocking {@link SyncHttpRequestCustomizer} in a non-blocking + * context. Use {@link #asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer)} + * instead. + * @param syncHttpRequestCustomizer the request customizer + * @return this builder + */ + public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCustomizer) { + this.httpRequestCustomizer = AsyncHttpRequestCustomizer.fromSync(syncHttpRequestCustomizer); + return this; + } + + /** + * Sets the customizer for {@link HttpRequest.Builder}, to modify requests before + * executing them. + *

+ * This overrides the customizer from + * {@link #httpRequestCustomizer(SyncHttpRequestCustomizer)}. + *

+ * Do NOT use a blocking implementation in a non-blocking context. + * @param asyncHttpRequestCustomizer the request customizer + * @return this builder + */ + public Builder asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer asyncHttpRequestCustomizer) { + this.httpRequestCustomizer = asyncHttpRequestCustomizer; + return this; + } + /** * Builds a new {@link HttpClientSseClientTransport} instance. * @return a new transport instance */ public HttpClientSseClientTransport build() { return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint, - objectMapper); + objectMapper, httpRequestCustomizer); } } @Override public Mono connect(Function, Mono> handler) { + var uri = Utils.resolveUri(this.baseUri, this.sseEndpoint); - return Mono.create(sink -> { - - HttpRequest request = requestBuilder.copy() - .uri(Utils.resolveUri(this.baseUri, this.sseEndpoint)) + return Mono.defer(() -> { + var builder = requestBuilder.copy() + .uri(uri) .header("Accept", "text/event-stream") .header("Cache-Control", "no-cache") - .GET() - .build(); - + .header(MCP_PROTOCOL_VERSION_HEADER_NAME, MCP_PROTOCOL_VERSION) + .GET(); + return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null)); + }).flatMap(requestBuilder -> Mono.create(sink -> { Disposable connection = Flux.create(sseSink -> this.httpClient - .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) + .sendAsync(requestBuilder.build(), + responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) .exceptionallyCompose(e -> { sseSink.error(e); return CompletableFuture.failedFuture(e); @@ -366,10 +441,8 @@ else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { return Flux.just(message); } else { - logger.error("Received unrecognized SSE event type: {}", - responseEvent.sseEvent().event()); - sink.error(new McpError( - "Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); + logger.debug("Received unrecognized SSE event type: {}", responseEvent.sseEvent()); + sink.success(); } } catch (IOException e) { @@ -399,7 +472,7 @@ else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { .subscribe(); this.sseSubscription.set(connection); - }); + })); } /** @@ -455,13 +528,16 @@ private Mono serializeMessage(final JSONRPCMessage message) { private Mono> sendHttpPost(final String endpoint, final String body) { final URI requestUri = Utils.resolveUri(baseUri, endpoint); - final HttpRequest request = this.requestBuilder.copy() - .uri(requestUri) - .POST(HttpRequest.BodyPublishers.ofString(body)) - .build(); - - // TODO: why discard the body? - return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())); + return Mono.defer(() -> { + var builder = this.requestBuilder.copy() + .uri(requestUri) + .header(MCP_PROTOCOL_VERSION_HEADER_NAME, MCP_PROTOCOL_VERSION) + .POST(HttpRequest.BodyPublishers.ofString(body)); + return Mono.from(this.httpRequestCustomizer.customize(builder, "POST", requestUri, body)); + }).flatMap(customizedBuilder -> { + var request = customizedBuilder.build(); + return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())); + }); } /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java index 4cf1690ff..a9e5897b9 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransport.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2025 the original author or authors. */ package io.modelcontextprotocol.client.transport; @@ -28,12 +28,14 @@ import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; import io.modelcontextprotocol.spec.DefaultMcpTransportSession; import io.modelcontextprotocol.spec.DefaultMcpTransportStream; +import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpTransportSession; import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException; import io.modelcontextprotocol.spec.McpTransportStream; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.Utils; import reactor.core.Disposable; @@ -72,6 +74,8 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private static final Logger logger = LoggerFactory.getLogger(HttpClientStreamableHttpTransport.class); + private static final String MCP_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_03_26; + private static final String DEFAULT_ENDPOINT = "/mcp"; /** @@ -109,6 +113,8 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private final boolean resumableStreams; + private final AsyncHttpRequestCustomizer httpRequestCustomizer; + private final AtomicReference activeSession = new AtomicReference<>(); private final AtomicReference, Mono>> handler = new AtomicReference<>(); @@ -117,7 +123,7 @@ public class HttpClientStreamableHttpTransport implements McpClientTransport { private HttpClientStreamableHttpTransport(ObjectMapper objectMapper, HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, String endpoint, boolean resumableStreams, - boolean openConnectionOnStartup) { + boolean openConnectionOnStartup, AsyncHttpRequestCustomizer httpRequestCustomizer) { this.objectMapper = objectMapper; this.httpClient = httpClient; this.requestBuilder = requestBuilder; @@ -126,6 +132,12 @@ private HttpClientStreamableHttpTransport(ObjectMapper objectMapper, HttpClient this.resumableStreams = resumableStreams; this.openConnectionOnStartup = openConnectionOnStartup; this.activeSession.set(createTransportSession()); + this.httpRequestCustomizer = httpRequestCustomizer; + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26); } public static Builder builder(String baseUri) { @@ -154,14 +166,20 @@ private DefaultMcpTransportSession createTransportSession() { } private Publisher createDelete(String sessionId) { - HttpRequest request = this.requestBuilder.copy() - .uri(Utils.resolveUri(this.baseUri, this.endpoint)) - .header("Cache-Control", "no-cache") - .header("mcp-session-id", sessionId) - .DELETE() - .build(); - - return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())).then(); + + var uri = Utils.resolveUri(this.baseUri, this.endpoint); + return Mono.defer(() -> { + var builder = this.requestBuilder.copy() + .uri(uri) + .header("Cache-Control", "no-cache") + .header(HttpHeaders.MCP_SESSION_ID, sessionId) + .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) + .DELETE(); + return Mono.from(this.httpRequestCustomizer.customize(builder, "DELETE", uri, null)); + }).flatMap(requestBuilder -> { + var request = requestBuilder.build(); + return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())); + }).then(); } @Override @@ -208,96 +226,112 @@ private Mono reconnect(McpTransportStream stream) { final AtomicReference disposableRef = new AtomicReference<>(); final McpTransportSession transportSession = this.activeSession.get(); + var uri = Utils.resolveUri(this.baseUri, this.endpoint); - HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); - - if (transportSession != null && transportSession.sessionId().isPresent()) { - requestBuilder = requestBuilder.header("mcp-session-id", transportSession.sessionId().get()); - } - - if (stream != null && stream.lastId().isPresent()) { - requestBuilder = requestBuilder.header("last-event-id", stream.lastId().get()); - } - - HttpRequest request = requestBuilder.uri(Utils.resolveUri(this.baseUri, this.endpoint)) - .header("Accept", TEXT_EVENT_STREAM) - .header("Cache-Control", "no-cache") - .GET() - .build(); - - Disposable connection = Flux.create(sseSink -> this.httpClient - .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) - .whenComplete((response, throwable) -> { - if (throwable != null) { - sseSink.error(throwable); - } - else { - logger.debug("SSE connection established successfully"); - } - })) - .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) - .flatMap(responseEvent -> { - int statusCode = responseEvent.responseInfo().statusCode(); - - if (statusCode >= 200 && statusCode < 300) { + Disposable connection = Mono.defer(() -> { + HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); - if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { - try { - // We don't support batching ATM and probably won't since - // the - // next version considers removing it. - McpSchema.JSONRPCMessage message = McpSchema - .deserializeJsonRpcMessage(this.objectMapper, responseEvent.sseEvent().data()); - - Tuple2, Iterable> idWithMessages = Tuples - .of(Optional.ofNullable(responseEvent.sseEvent().id()), List.of(message)); + if (transportSession != null && transportSession.sessionId().isPresent()) { + requestBuilder = requestBuilder.header(HttpHeaders.MCP_SESSION_ID, + transportSession.sessionId().get()); + } - McpTransportStream sessionStream = stream != null ? stream - : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect); - logger.debug("Connected stream {}", sessionStream.streamId()); + if (stream != null && stream.lastId().isPresent()) { + requestBuilder = requestBuilder.header(HttpHeaders.LAST_EVENT_ID, stream.lastId().get()); + } - return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + var builder = requestBuilder.uri(uri) + .header("Accept", TEXT_EVENT_STREAM) + .header("Cache-Control", "no-cache") + .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) + .GET(); + return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null)); + }) + .flatMapMany( + requestBuilder -> Flux.create( + sseSink -> this.httpClient + .sendAsync(requestBuilder.build(), + responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, + sseSink)) + .whenComplete((response, throwable) -> { + if (throwable != null) { + sseSink.error(throwable); + } + else { + logger.debug("SSE connection established successfully"); + } + })) + .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) + .flatMap(responseEvent -> { + int statusCode = responseEvent.responseInfo().statusCode(); + + if (statusCode >= 200 && statusCode < 300) { + + if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { + try { + // We don't support batching ATM and probably + // won't since the next version considers + // removing it. + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage( + this.objectMapper, responseEvent.sseEvent().data()); + + Tuple2, Iterable> idWithMessages = Tuples + .of(Optional.ofNullable(responseEvent.sseEvent().id()), + List.of(message)); + + McpTransportStream sessionStream = stream != null ? stream + : new DefaultMcpTransportStream<>(this.resumableStreams, + this::reconnect); + logger.debug("Connected stream {}", sessionStream.streamId()); + + return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); + + } + catch (IOException ioException) { + return Flux.error( + new McpError("Error parsing JSON-RPC message: " + + responseEvent.sseEvent().data())); + } + } + else { + logger.debug("Received SSE event with type: {}", responseEvent.sseEvent()); + return Flux.empty(); + } + } + else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed + logger + .debug("The server does not support SSE streams, using request-response mode."); + return Flux.empty(); + } + else if (statusCode == NOT_FOUND) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } + else if (statusCode == BAD_REQUEST) { + String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); + McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( + "Session not found for session ID: " + sessionIdRepresentation); + return Flux.error(exception); + } - } - catch (IOException ioException) { return Flux.error(new McpError( - "Error parsing JSON-RPC message: " + responseEvent.sseEvent().data())); - } - } - } - else if (statusCode == METHOD_NOT_ALLOWED) { // NotAllowed - logger.debug("The server does not support SSE streams, using request-response mode."); - return Flux.empty(); - } - else if (statusCode == NOT_FOUND) { - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionIdRepresentation); - return Flux.error(exception); - } - else if (statusCode == BAD_REQUEST) { - String sessionIdRepresentation = sessionIdOrPlaceholder(transportSession); - McpTransportSessionNotFoundException exception = new McpTransportSessionNotFoundException( - "Session not found for session ID: " + sessionIdRepresentation); - return Flux.error(exception); - } - - return Flux.error( - new McpError("Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); - - }).flatMap(jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) - .onErrorMap(CompletionException.class, t -> t.getCause()) - .onErrorComplete(t -> { - this.handleException(t); - return true; - }) - .doFinally(s -> { - Disposable ref = disposableRef.getAndSet(null); - if (ref != null) { - transportSession.removeConnection(ref); - } - }) + "Received unrecognized SSE event type: " + responseEvent.sseEvent().event())); + }).flatMap( + jsonrpcMessage -> this.handler.get().apply(Mono.just(jsonrpcMessage))) + .onErrorMap(CompletionException.class, t -> t.getCause()) + .onErrorComplete(t -> { + this.handleException(t); + return true; + }) + .doFinally(s -> { + Disposable ref = disposableRef.getAndSet(null); + if (ref != null) { + transportSession.removeConnection(ref); + } + })) .contextWrite(ctx) .subscribe(); @@ -342,32 +376,36 @@ public String toString(McpSchema.JSONRPCMessage message) { } } - public Mono sendMessage(McpSchema.JSONRPCMessage sendMessage) { - return Mono.create(messageSink -> { - logger.debug("Sending message {}", sendMessage); + public Mono sendMessage(McpSchema.JSONRPCMessage sentMessage) { + return Mono.create(deliveredSink -> { + logger.debug("Sending message {}", sentMessage); final AtomicReference disposableRef = new AtomicReference<>(); final McpTransportSession transportSession = this.activeSession.get(); - HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); - - if (transportSession != null && transportSession.sessionId().isPresent()) { - requestBuilder = requestBuilder.header("mcp-session-id", transportSession.sessionId().get()); - } + var uri = Utils.resolveUri(this.baseUri, this.endpoint); + String jsonBody = this.toString(sentMessage); - String jsonBody = this.toString(sendMessage); + Disposable connection = Mono.defer(() -> { + HttpRequest.Builder requestBuilder = this.requestBuilder.copy(); - HttpRequest request = requestBuilder.uri(Utils.resolveUri(this.baseUri, this.endpoint)) - .header("Accept", TEXT_EVENT_STREAM + ", " + APPLICATION_JSON) - .header("Content-Type", APPLICATION_JSON) - .header("Cache-Control", "no-cache") - .POST(HttpRequest.BodyPublishers.ofString(jsonBody)) - .build(); + if (transportSession != null && transportSession.sessionId().isPresent()) { + requestBuilder = requestBuilder.header(HttpHeaders.MCP_SESSION_ID, + transportSession.sessionId().get()); + } - Disposable connection = Flux.create(responseEventSink -> { + var builder = requestBuilder.uri(uri) + .header("Accept", APPLICATION_JSON + ", " + TEXT_EVENT_STREAM) + .header("Content-Type", APPLICATION_JSON) + .header("Cache-Control", "no-cache") + .header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION) + .POST(HttpRequest.BodyPublishers.ofString(jsonBody)); + return Mono.from(this.httpRequestCustomizer.customize(builder, "POST", uri, jsonBody)); + }).flatMapMany(requestBuilder -> Flux.create(responseEventSink -> { // Create the async request with proper body subscriber selection - Mono.fromFuture(this.httpClient.sendAsync(request, this.toSendMessageBodySubscriber(responseEventSink)) + Mono.fromFuture(this.httpClient + .sendAsync(requestBuilder.build(), this.toSendMessageBodySubscriber(responseEventSink)) .whenComplete((response, throwable) -> { if (throwable != null) { responseEventSink.error(throwable); @@ -377,13 +415,13 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sendMessage) { } })).onErrorMap(CompletionException.class, t -> t.getCause()).onErrorComplete().subscribe(); - }).flatMap(responseEvent -> { + })).flatMap(responseEvent -> { if (transportSession.markInitialized( responseEvent.responseInfo().headers().firstValue("mcp-session-id").orElseGet(() -> null))) { // Once we have a session, we try to open an async stream for // the server to send notifications and requests out-of-band. - reconnect(null).contextWrite(messageSink.contextView()).subscribe(); + reconnect(null).contextWrite(deliveredSink.contextView()).subscribe(); } String sessionRepresentation = sessionIdOrPlaceholder(transportSession); @@ -400,16 +438,18 @@ public Mono sendMessage(McpSchema.JSONRPCMessage sendMessage) { if (contentType.isBlank()) { logger.debug("No content type returned for POST in session {}", sessionRepresentation); - // No content type means no response body, so we can just return + // No content type means no response body, so we can just + // return // an empty stream - messageSink.success(); + deliveredSink.success(); return Flux.empty(); } else if (contentType.contains(TEXT_EVENT_STREAM)) { return Flux.just(((ResponseSubscribers.SseResponseEvent) responseEvent).sseEvent()) .flatMap(sseEvent -> { try { - // We don't support batching ATM and probably won't + // We don't support batching ATM and probably + // won't // since the // next version considers removing it. McpSchema.JSONRPCMessage message = McpSchema @@ -423,7 +463,7 @@ else if (contentType.contains(TEXT_EVENT_STREAM)) { logger.debug("Connected stream {}", sessionStream.streamId()); - messageSink.success(); + deliveredSink.success(); return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages))); } @@ -434,12 +474,18 @@ else if (contentType.contains(TEXT_EVENT_STREAM)) { }); } else if (contentType.contains(APPLICATION_JSON)) { - messageSink.success(); + deliveredSink.success(); String data = ((ResponseSubscribers.AggregateResponseEvent) responseEvent).data(); + if (sentMessage instanceof McpSchema.JSONRPCNotification && Utils.hasText(data)) { + logger.warn("Notification: {} received non-compliant response: {}", sentMessage, data); + return Mono.empty(); + } + try { return Mono.just(McpSchema.deserializeJsonRpcMessage(objectMapper, data)); } catch (IOException e) { + // TODO: this should be a McpTransportError return Mono.error(e); } } @@ -473,7 +519,7 @@ else if (statusCode == BAD_REQUEST) { // handle the error first this.handleException(t); // inform the caller of sendMessage - messageSink.error(t); + deliveredSink.error(t); return true; }) .doFinally(s -> { @@ -483,7 +529,7 @@ else if (statusCode == BAD_REQUEST) { transportSession.removeConnection(ref); } }) - .contextWrite(messageSink.contextView()) + .contextWrite(deliveredSink.contextView()) .subscribe(); disposableRef.set(connection); @@ -521,6 +567,8 @@ public static class Builder { private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + private AsyncHttpRequestCustomizer httpRequestCustomizer = AsyncHttpRequestCustomizer.NOOP; + /** * Creates a new builder with the specified base URI. * @param baseUri the base URI of the MCP server @@ -623,6 +671,40 @@ public Builder openConnectionOnStartup(boolean openConnectionOnStartup) { return this; } + /** + * Sets the customizer for {@link HttpRequest.Builder}, to modify requests before + * executing them. + *

+ * This overrides the customizer from + * {@link #asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer)}. + *

+ * Do NOT use a blocking {@link SyncHttpRequestCustomizer} in a non-blocking + * context. Use {@link #asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer)} + * instead. + * @param syncHttpRequestCustomizer the request customizer + * @return this builder + */ + public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCustomizer) { + this.httpRequestCustomizer = AsyncHttpRequestCustomizer.fromSync(syncHttpRequestCustomizer); + return this; + } + + /** + * Sets the customizer for {@link HttpRequest.Builder}, to modify requests before + * executing them. + *

+ * This overrides the customizer from + * {@link #httpRequestCustomizer(SyncHttpRequestCustomizer)}. + *

+ * Do NOT use a blocking implementation in a non-blocking context. + * @param asyncHttpRequestCustomizer the request customizer + * @return this builder + */ + public Builder asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer asyncHttpRequestCustomizer) { + this.httpRequestCustomizer = asyncHttpRequestCustomizer; + return this; + } + /** * Construct a fresh instance of {@link HttpClientStreamableHttpTransport} using * the current builder configuration. @@ -632,7 +714,7 @@ public HttpClientStreamableHttpTransport build() { ObjectMapper objectMapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); return new HttpClientStreamableHttpTransport(objectMapper, clientBuilder.build(), requestBuilder, baseUri, - endpoint, resumableStreams, openConnectionOnStartup); + endpoint, resumableStreams, openConnectionOnStartup, httpRequestCustomizer); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java index 26b0d13bd..4d9bdea5d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/ResponseSubscribers.java @@ -1,6 +1,7 @@ /* * Copyright 2024 - 2024 the original author or authors. */ + package io.modelcontextprotocol.client.transport; import java.net.http.HttpResponse; @@ -11,7 +12,10 @@ import org.reactivestreams.FlowAdapters; import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import io.modelcontextprotocol.spec.McpError; import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.FluxSink; @@ -29,6 +33,8 @@ */ class ResponseSubscribers { + private static final Logger logger = LoggerFactory.getLogger(ResponseSubscribers.class); + record SseEvent(String id, String event, String data) { } @@ -135,6 +141,7 @@ protected void hookOnSubscribe(Subscription subscription) { @Override protected void hookOnNext(String line) { + if (line.isEmpty()) { // Empty line means end of event if (this.eventBuilder.length() > 0) { @@ -164,6 +171,18 @@ else if (line.startsWith("event:")) { this.currentEventType.set(matcher.group(1).trim()); } } + else if (line.startsWith(":")) { + // Ignore comment lines starting with ":" + // This is a no-op, just to skip comments + logger.debug("Ignoring comment line: {}", line); + } + else { + // If the response is not successful, emit an error + // TODO: This should be a McpTransportError + this.sink.error(new McpError( + "Invalid SSE response. Status code: " + this.responseInfo.statusCode() + " Line: " + line)); + + } } } @@ -228,10 +247,9 @@ protected void hookOnNext(String line) { @Override protected void hookOnComplete() { - if (this.eventBuilder.length() > 0) { - String data = this.eventBuilder.toString(); - this.sink.next(new AggregateResponseEvent(responseInfo, data)); - } + String data = this.eventBuilder.toString(); + this.sink.next(new AggregateResponseEvent(responseInfo, data)); + this.sink.complete(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/SyncHttpRequestCustomizer.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/SyncHttpRequestCustomizer.java new file mode 100644 index 000000000..72b6e6c1b --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/SyncHttpRequestCustomizer.java @@ -0,0 +1,21 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.net.URI; +import java.net.http.HttpRequest; +import reactor.util.annotation.Nullable; + +/** + * Customize {@link HttpRequest.Builder} before executing the request, either in SSE or + * Streamable HTTP transport. + * + * @author Daniel Garnier-Moiroux + */ +public interface SyncHttpRequestCustomizer { + + void customize(HttpRequest.Builder builder, String method, URI endpoint, @Nullable String body); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java new file mode 100644 index 000000000..2df3514b6 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpStatelessServerHandler.java @@ -0,0 +1,62 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +import java.util.Map; + +class DefaultMcpStatelessServerHandler implements McpStatelessServerHandler { + + private static final Logger logger = LoggerFactory.getLogger(DefaultMcpStatelessServerHandler.class); + + Map> requestHandlers; + + Map notificationHandlers; + + public DefaultMcpStatelessServerHandler(Map> requestHandlers, + Map notificationHandlers) { + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + @Override + public Mono handleRequest(McpTransportContext transportContext, + McpSchema.JSONRPCRequest request) { + McpStatelessRequestHandler requestHandler = this.requestHandlers.get(request.method()); + if (requestHandler == null) { + return Mono.error(new McpError("Missing handler for request type: " + request.method())); + } + return requestHandler.handle(transportContext, request.params()) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .onErrorResume(t -> { + McpSchema.JSONRPCResponse.JSONRPCError error; + if (t instanceof McpError mcpError && mcpError.getJsonRpcError() != null) { + error = mcpError.getJsonRpcError(); + } + else { + error = new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + t.getMessage(), null); + } + return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, error)); + }); + } + + @Override + public Mono handleNotification(McpTransportContext transportContext, + McpSchema.JSONRPCNotification notification) { + McpStatelessNotificationHandler notificationHandler = this.notificationHandlers.get(notification.method()); + if (notificationHandler == null) { + logger.warn("Missing handler for notification type: {}", notification.method()); + return Mono.empty(); + } + return notificationHandler.handle(transportContext, notification.params()); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java new file mode 100644 index 000000000..9e18e189d --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/DefaultMcpTransportContext.java @@ -0,0 +1,49 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Default implementation for {@link McpTransportContext} which uses a Thread-safe map. + * Objects of this kind are mutable. + * + * @author Dariusz Jędrzejczyk + */ +public class DefaultMcpTransportContext implements McpTransportContext { + + private final Map storage; + + /** + * Create an empty instance. + */ + public DefaultMcpTransportContext() { + this.storage = new ConcurrentHashMap<>(); + } + + DefaultMcpTransportContext(Map storage) { + this.storage = storage; + } + + @Override + public Object get(String key) { + return this.storage.get(key); + } + + @Override + public void put(String key, Object value) { + this.storage.put(key, value); + } + + /** + * Allows copying the contents. + * @return new instance with the copy of the underlying map + */ + public McpTransportContext copy() { + return new DefaultMcpTransportContext(new ConcurrentHashMap<>(this.storage)); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 7131b10fa..5b5e838f3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -15,6 +15,9 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.function.BiFunction; +import io.modelcontextprotocol.spec.DefaultMcpStreamableServerSessionFactory; +import io.modelcontextprotocol.spec.McpServerTransportProviderBase; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -86,7 +89,7 @@ public class McpAsyncServer { private static final Logger logger = LoggerFactory.getLogger(McpAsyncServer.class); - private final McpServerTransportProvider mcpTransportProvider; + private final McpServerTransportProviderBase mcpTransportProvider; private final ObjectMapper objectMapper; @@ -112,7 +115,7 @@ public class McpAsyncServer { private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); - private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + private List protocolVersions; private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); @@ -139,7 +142,60 @@ public class McpAsyncServer { this.uriTemplateManagerFactory = uriTemplateManagerFactory; this.jsonSchemaValidator = jsonSchemaValidator; - Map> requestHandlers = new HashMap<>(); + Map> requestHandlers = prepareRequestHandlers(); + Map notificationHandlers = prepareNotificationHandlers(features); + + this.protocolVersions = mcpTransportProvider.protocolVersions(); + + mcpTransportProvider.setSessionFactory(transport -> new McpServerSession(UUID.randomUUID().toString(), + requestTimeout, transport, this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); + } + + McpAsyncServer(McpStreamableServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + this.mcpTransportProvider = mcpTransportProvider; + this.objectMapper = objectMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.instructions = features.instructions(); + this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); + this.resources.putAll(features.resources()); + this.resourceTemplates.addAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + this.jsonSchemaValidator = jsonSchemaValidator; + + Map> requestHandlers = prepareRequestHandlers(); + Map notificationHandlers = prepareNotificationHandlers(features); + + this.protocolVersions = mcpTransportProvider.protocolVersions(); + + mcpTransportProvider.setSessionFactory(new DefaultMcpStreamableServerSessionFactory(requestTimeout, + this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); + } + + private Map prepareNotificationHandlers(McpServerFeatures.Async features) { + Map notificationHandlers = new HashMap<>(); + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); + + List, Mono>> rootsChangeConsumers = features + .rootsChangeConsumers(); + + if (Utils.isEmpty(rootsChangeConsumers)) { + rootsChangeConsumers = List.of((exchange, roots) -> Mono.fromRunnable(() -> logger + .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, + asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); + return notificationHandlers; + } + + private Map> prepareRequestHandlers() { + Map> requestHandlers = new HashMap<>(); // Initialize request handlers for standard MCP methods @@ -174,25 +230,7 @@ public class McpAsyncServer { if (this.serverCapabilities.completions() != null) { requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); } - - Map notificationHandlers = new HashMap<>(); - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); - - List, Mono>> rootsChangeConsumers = features - .rootsChangeConsumers(); - - if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((exchange, roots) -> Mono.fromRunnable(() -> logger - .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); - } - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, - asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - - mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + return requestHandlers; } // --------------------------------------- @@ -258,11 +296,11 @@ public void close() { this.mcpTransportProvider.close(); } - private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler( + private McpNotificationHandler asyncRootsListChangedNotificationHandler( List, Mono>> rootsChangeConsumers) { return (exchange, params) -> exchange.listRoots() .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) - .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) + .flatMap(consumer -> Mono.defer(() -> consumer.apply(exchange, listRootsResult.roots()))) .onErrorResume(error -> { logger.error("Error handling roots list change notification", error); return Mono.empty(); @@ -450,7 +488,7 @@ public Mono notifyToolsListChanged() { return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); } - private McpServerSession.RequestHandler toolsListRequestHandler() { + private McpRequestHandler toolsListRequestHandler() { return (exchange, params) -> { List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); @@ -458,7 +496,7 @@ private McpServerSession.RequestHandler toolsListRequ }; } - private McpServerSession.RequestHandler toolsCallRequestHandler() { + private McpRequestHandler toolsCallRequestHandler() { return (exchange, params) -> { McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, new TypeReference() { @@ -472,7 +510,7 @@ private McpServerSession.RequestHandler toolsCallRequestHandler( return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); } - return toolSpecification.map(tool -> tool.callHandler().apply(exchange, callToolRequest)) + return toolSpecification.map(tool -> Mono.defer(() -> tool.callHandler().apply(exchange, callToolRequest))) .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); }; } @@ -551,7 +589,7 @@ public Mono notifyResourcesUpdated(McpSchema.ResourcesUpdatedNotification resourcesUpdatedNotification); } - private McpServerSession.RequestHandler resourcesListRequestHandler() { + private McpRequestHandler resourcesListRequestHandler() { return (exchange, params) -> { var resourceList = this.resources.values() .stream() @@ -561,7 +599,7 @@ private McpServerSession.RequestHandler resources }; } - private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { + private McpRequestHandler resourceTemplateListRequestHandler() { return (exchange, params) -> Mono .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); @@ -585,7 +623,7 @@ private List getResourceTemplates() { return list; } - private McpServerSession.RequestHandler resourcesReadRequestHandler() { + private McpRequestHandler resourcesReadRequestHandler() { return (exchange, params) -> { McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, new TypeReference() { @@ -600,7 +638,7 @@ private McpServerSession.RequestHandler resourcesR .findFirst() .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); - return specification.readHandler().apply(exchange, resourceRequest); + return Mono.defer(() -> specification.readHandler().apply(exchange, resourceRequest)); }; } @@ -678,7 +716,7 @@ public Mono notifyPromptsListChanged() { return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); } - private McpServerSession.RequestHandler promptsListRequestHandler() { + private McpRequestHandler promptsListRequestHandler() { return (exchange, params) -> { // TODO: Implement pagination // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, @@ -694,7 +732,7 @@ private McpServerSession.RequestHandler promptsList }; } - private McpServerSession.RequestHandler promptsGetRequestHandler() { + private McpRequestHandler promptsGetRequestHandler() { return (exchange, params) -> { McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, new TypeReference() { @@ -706,7 +744,7 @@ private McpServerSession.RequestHandler promptsGetReq return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); } - return specification.promptHandler().apply(exchange, promptRequest); + return Mono.defer(() -> specification.promptHandler().apply(exchange, promptRequest)); }; } @@ -740,7 +778,7 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN loggingMessageNotification); } - private McpServerSession.RequestHandler setLoggerRequestHandler() { + private McpRequestHandler setLoggerRequestHandler() { return (exchange, params) -> { return Mono.defer(() -> { @@ -759,7 +797,7 @@ private McpServerSession.RequestHandler setLoggerRequestHandler() { }; } - private McpServerSession.RequestHandler completionCompleteRequestHandler() { + private McpRequestHandler completionCompleteRequestHandler() { return (exchange, params) -> { McpSchema.CompleteRequest request = parseCompletionParams(params); @@ -811,7 +849,7 @@ private McpServerSession.RequestHandler completionComp return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); } - return specification.completionHandler().apply(exchange, request); + return Mono.defer(() -> specification.completionHandler().apply(exchange, request)); }; } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index c0923e10e..61d60bacc 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -9,10 +9,11 @@ import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpLoggableSession; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; -import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpSession; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; @@ -25,13 +26,15 @@ */ public class McpAsyncServerExchange { - private final McpServerSession session; + private final String sessionId; + + private final McpLoggableSession session; private final McpSchema.ClientCapabilities clientCapabilities; private final McpSchema.Implementation clientInfo; - private volatile LoggingLevel minLoggingLevel = LoggingLevel.INFO; + private final McpTransportContext transportContext; private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { }; @@ -51,12 +54,39 @@ public class McpAsyncServerExchange { * @param clientCapabilities The client capabilities that define the supported * features and functionality. * @param clientInfo The client implementation information. + * @deprecated Use + * {@link #McpAsyncServerExchange(String, McpLoggableSession, McpSchema.ClientCapabilities, McpSchema.Implementation, McpTransportContext)} */ - public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities, + @Deprecated + public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { + this.sessionId = null; + if (!(session instanceof McpLoggableSession)) { + throw new IllegalArgumentException("Expecting session to be a McpLoggableSession instance"); + } + this.session = (McpLoggableSession) session; + this.clientCapabilities = clientCapabilities; + this.clientInfo = clientInfo; + this.transportContext = McpTransportContext.EMPTY; + } + + /** + * Create a new asynchronous exchange with the client. + * @param session The server session representing a 1-1 interaction. + * @param clientCapabilities The client capabilities that define the supported + * features and functionality. + * @param clientInfo The client implementation information. + * @param transportContext context associated with the client as extracted from the + * transport + */ + public McpAsyncServerExchange(String sessionId, McpLoggableSession session, + McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, + McpTransportContext transportContext) { + this.sessionId = sessionId; this.session = session; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; + this.transportContext = transportContext; } /** @@ -75,6 +105,24 @@ public McpSchema.Implementation getClientInfo() { return this.clientInfo; } + /** + * Provides the {@link McpTransportContext} associated with the transport layer. For + * HTTP transports it can contain the metadata associated with the HTTP request that + * triggered the processing. + * @return the transport context object + */ + public McpTransportContext transportContext() { + return this.transportContext; + } + + /** + * Provides the Session ID. + * @return session ID string + */ + public String sessionId() { + return this.sessionId; + } + /** * Create a new message using the sampling capabilities of the client. The Model * Context Protocol (MCP) provides a standardized way for servers to request LLM @@ -170,7 +218,7 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN } return Mono.defer(() -> { - if (this.isNotificationForLevelAllowed(loggingMessageNotification.level())) { + if (this.session.isNotificationForLevelAllowed(loggingMessageNotification.level())) { return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, loggingMessageNotification); } return Mono.empty(); @@ -205,11 +253,7 @@ public Mono ping() { */ void setMinLoggingLevel(LoggingLevel minLoggingLevel) { Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); - this.minLoggingLevel = minLoggingLevel; - } - - private boolean isNotificationForLevelAllowed(LoggingLevel loggingLevel) { - return loggingLevel.level() >= this.minLoggingLevel.level(); + this.session.setMinLoggingLevel(minLoggingLevel); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java new file mode 100644 index 000000000..13ff45a54 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java @@ -0,0 +1,22 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpSchema; +import reactor.core.publisher.Mono; + +/** + * Request handler for the initialization request. + */ +public interface McpInitRequestHandler { + + /** + * Handles the initialization request. + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ + Mono handle(McpSchema.InitializeRequest initializeRequest); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java new file mode 100644 index 000000000..6b1061c03 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java @@ -0,0 +1,19 @@ +package io.modelcontextprotocol.server; + +import reactor.core.publisher.Mono; + +/** + * A handler for client-initiated notifications. + */ +public interface McpNotificationHandler { + + /** + * Handles a notification from the client. + * @param exchange the exchange associated with the client that allows calling back to + * the connected client or inspecting its capabilities. + * @param params the parameters of the notification. + * @return a Mono that completes once the notification is handled. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java new file mode 100644 index 000000000..c9d70ad04 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java @@ -0,0 +1,22 @@ +package io.modelcontextprotocol.server; + +import reactor.core.publisher.Mono; + +/** + * A handler for client-initiated requests. + * + * @param the type of the response that is expected as a result of handling the + * request. + */ +public interface McpRequestHandler { + + /** + * Handles a request from the client. + * @param exchange the exchange associated with the client that allows calling back to + * the connected client or inspecting its capabilities. + * @param params the parameters of the request. + * @return a Mono that will emit the response to the request. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index d4b8addf4..f5dfffffb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -21,6 +21,8 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; @@ -131,6 +133,8 @@ */ public interface McpServer { + McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", "1.0.0"); + /** * Starts building a synchronous MCP server that provides blocking operations. * Synchronous servers block the current Thread's execution upon each request before @@ -139,8 +143,8 @@ public interface McpServer { * @param transportProvider The transport layer implementation for MCP communication. * @return A new instance of {@link SyncSpecification} for configuring the server. */ - static SyncSpecification sync(McpServerTransportProvider transportProvider) { - return new SyncSpecification(transportProvider); + static SingleSessionSyncSpecification sync(McpServerTransportProvider transportProvider) { + return new SingleSessionSyncSpecification(transportProvider); } /** @@ -151,31 +155,129 @@ static SyncSpecification sync(McpServerTransportProvider transportProvider) { * @param transportProvider The transport layer implementation for MCP communication. * @return A new instance of {@link AsyncSpecification} for configuring the server. */ - static AsyncSpecification async(McpServerTransportProvider transportProvider) { - return new AsyncSpecification(transportProvider); + static AsyncSpecification async(McpServerTransportProvider transportProvider) { + return new SingleSessionAsyncSpecification(transportProvider); } /** - * Asynchronous server specification. + * Starts building a synchronous MCP server that provides blocking operations. + * Synchronous servers block the current Thread's execution upon each request before + * giving the control back to the caller, making them simpler to implement but + * potentially less scalable for concurrent operations. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link SyncSpecification} for configuring the server. + */ + static StreamableSyncSpecification sync(McpStreamableServerTransportProvider transportProvider) { + return new StreamableSyncSpecification(transportProvider); + } + + /** + * Starts building an asynchronous MCP server that provides non-blocking operations. + * Asynchronous servers can handle multiple requests concurrently on a single Thread + * using a functional paradigm with non-blocking server transports, making them more + * scalable for high-concurrency scenarios but more complex to implement. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link AsyncSpecification} for configuring the server. + */ + static AsyncSpecification async(McpStreamableServerTransportProvider transportProvider) { + return new StreamableServerAsyncSpecification(transportProvider); + } + + /** + * Starts building an asynchronous MCP server that provides non-blocking operations. + * Asynchronous servers can handle multiple requests concurrently on a single Thread + * using a functional paradigm with non-blocking server transports, making them more + * scalable for high-concurrency scenarios but more complex to implement. + * @param transport The transport layer implementation for MCP communication. + * @return A new instance of {@link AsyncSpecification} for configuring the server. + */ + static StatelessAsyncSpecification async(McpStatelessServerTransport transport) { + return new StatelessAsyncSpecification(transport); + } + + /** + * Starts building a synchronous MCP server that provides blocking operations. + * Synchronous servers block the current Thread's execution upon each request before + * giving the control back to the caller, making them simpler to implement but + * potentially less scalable for concurrent operations. + * @param transport The transport layer implementation for MCP communication. + * @return A new instance of {@link SyncSpecification} for configuring the server. */ - class AsyncSpecification { + static StatelessSyncSpecification sync(McpStatelessServerTransport transport) { + return new StatelessSyncSpecification(transport); + } - private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", - "1.0.0"); + class SingleSessionAsyncSpecification extends AsyncSpecification { private final McpServerTransportProvider transportProvider; - private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + private SingleSessionAsyncSpecification(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + + /** + * Builds an asynchronous MCP server that provides non-blocking operations. + * @return A new instance of {@link McpAsyncServer} configured with this builder's + * settings. + */ + @Override + public McpAsyncServer build() { + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, + this.instructions); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(mapper); + return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); + } + + } + + class StreamableServerAsyncSpecification extends AsyncSpecification { + + private final McpStreamableServerTransportProvider transportProvider; + + public StreamableServerAsyncSpecification(McpStreamableServerTransportProvider transportProvider) { + this.transportProvider = transportProvider; + } + + /** + * Builds an asynchronous MCP server that provides non-blocking operations. + * @return A new instance of {@link McpAsyncServer} configured with this builder's + * settings. + */ + @Override + public McpAsyncServer build() { + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, + this.instructions); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(mapper); + return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); + } + + } + + /** + * Asynchronous server specification. + */ + abstract class AsyncSpecification> { + + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); - private ObjectMapper objectMapper; + ObjectMapper objectMapper; - private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; - private McpSchema.ServerCapabilities serverCapabilities; + McpSchema.ServerCapabilities serverCapabilities; - private JsonSchemaValidator jsonSchemaValidator; + JsonSchemaValidator jsonSchemaValidator; - private String instructions; + String instructions; /** * The Model Context Protocol (MCP) allows servers to expose tools that can be @@ -184,7 +286,7 @@ class AsyncSpecification { * Each tool is uniquely identified by a name and includes metadata describing its * schema. */ - private final List tools = new ArrayList<>(); + final List tools = new ArrayList<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -193,9 +295,9 @@ class AsyncSpecification { * application-specific information. Each resource is uniquely identified by a * URI. */ - private final Map resources = new HashMap<>(); + final Map resources = new HashMap<>(); - private final List resourceTemplates = new ArrayList<>(); + final List resourceTemplates = new ArrayList<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -204,18 +306,15 @@ class AsyncSpecification { * discover available prompts, retrieve their contents, and provide arguments to * customize them. */ - private final Map prompts = new HashMap<>(); + final Map prompts = new HashMap<>(); - private final Map completions = new HashMap<>(); + final Map completions = new HashMap<>(); - private final List, Mono>> rootsChangeHandlers = new ArrayList<>(); + final List, Mono>> rootsChangeHandlers = new ArrayList<>(); - private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + Duration requestTimeout = Duration.ofHours(10); // Default timeout - private AsyncSpecification(McpServerTransportProvider transportProvider) { - Assert.notNull(transportProvider, "Transport provider must not be null"); - this.transportProvider = transportProvider; - } + public abstract McpAsyncServer build(); /** * Sets the URI template manager factory to use for creating URI templates. This @@ -224,7 +323,7 @@ private AsyncSpecification(McpServerTransportProvider transportProvider) { * @return This builder instance for method chaining * @throws IllegalArgumentException if uriTemplateManagerFactory is null */ - public AsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + public AsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); this.uriTemplateManagerFactory = uriTemplateManagerFactory; return this; @@ -239,7 +338,7 @@ public AsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory * @return This builder instance for method chaining * @throws IllegalArgumentException if requestTimeout is null */ - public AsyncSpecification requestTimeout(Duration requestTimeout) { + public AsyncSpecification requestTimeout(Duration requestTimeout) { Assert.notNull(requestTimeout, "Request timeout must not be null"); this.requestTimeout = requestTimeout; return this; @@ -254,7 +353,7 @@ public AsyncSpecification requestTimeout(Duration requestTimeout) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverInfo is null */ - public AsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + public AsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { Assert.notNull(serverInfo, "Server info must not be null"); this.serverInfo = serverInfo; return this; @@ -270,7 +369,7 @@ public AsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { * @throws IllegalArgumentException if name or version is null or empty * @see #serverInfo(McpSchema.Implementation) */ - public AsyncSpecification serverInfo(String name, String version) { + public AsyncSpecification serverInfo(String name, String version) { Assert.hasText(name, "Name must not be null or empty"); Assert.hasText(version, "Version must not be null or empty"); this.serverInfo = new McpSchema.Implementation(name, version); @@ -284,7 +383,7 @@ public AsyncSpecification serverInfo(String name, String version) { * @param instructions The instructions text. Can be null or empty. * @return This builder instance for method chaining */ - public AsyncSpecification instructions(String instructions) { + public AsyncSpecification instructions(String instructions) { this.instructions = instructions; return this; } @@ -303,7 +402,7 @@ public AsyncSpecification instructions(String instructions) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverCapabilities is null */ - public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { Assert.notNull(serverCapabilities, "Server capabilities must not be null"); this.serverCapabilities = serverCapabilities; return this; @@ -334,7 +433,7 @@ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabi * calls that require a request object. */ @Deprecated - public AsyncSpecification tool(McpSchema.Tool tool, + public AsyncSpecification tool(McpSchema.Tool tool, BiFunction, Mono> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); @@ -358,7 +457,7 @@ public AsyncSpecification tool(McpSchema.Tool tool, * @return This builder instance for method chaining * @throws IllegalArgumentException if tool or handler is null */ - public AsyncSpecification toolCall(McpSchema.Tool tool, + public AsyncSpecification toolCall(McpSchema.Tool tool, BiFunction> callHandler) { Assert.notNull(tool, "Tool must not be null"); @@ -381,7 +480,7 @@ public AsyncSpecification toolCall(McpSchema.Tool tool, * @throws IllegalArgumentException if toolSpecifications is null * @see #tools(McpServerFeatures.AsyncToolSpecification...) */ - public AsyncSpecification tools(List toolSpecifications) { + public AsyncSpecification tools(List toolSpecifications) { Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (var tool : toolSpecifications) { @@ -408,7 +507,7 @@ public AsyncSpecification tools(List t * @return This builder instance for method chaining * @throws IllegalArgumentException if toolSpecifications is null */ - public AsyncSpecification tools(McpServerFeatures.AsyncToolSpecification... toolSpecifications) { + public AsyncSpecification tools(McpServerFeatures.AsyncToolSpecification... toolSpecifications) { Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (McpServerFeatures.AsyncToolSpecification tool : toolSpecifications) { @@ -434,7 +533,7 @@ private void assertNoDuplicateTool(String toolName) { * @throws IllegalArgumentException if resourceSpecifications is null * @see #resources(McpServerFeatures.AsyncResourceSpecification...) */ - public AsyncSpecification resources( + public AsyncSpecification resources( Map resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); this.resources.putAll(resourceSpecifications); @@ -450,7 +549,8 @@ public AsyncSpecification resources( * @throws IllegalArgumentException if resourceSpecifications is null * @see #resources(McpServerFeatures.AsyncResourceSpecification...) */ - public AsyncSpecification resources(List resourceSpecifications) { + public AsyncSpecification resources( + List resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); @@ -475,7 +575,7 @@ public AsyncSpecification resources(List resources(McpServerFeatures.AsyncResourceSpecification... resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); @@ -500,7 +600,7 @@ public AsyncSpecification resources(McpServerFeatures.AsyncResourceSpecification * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(ResourceTemplate...) */ - public AsyncSpecification resourceTemplates(List resourceTemplates) { + public AsyncSpecification resourceTemplates(List resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); this.resourceTemplates.addAll(resourceTemplates); return this; @@ -514,7 +614,7 @@ public AsyncSpecification resourceTemplates(List resourceTempl * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(List) */ - public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); for (ResourceTemplate resourceTemplate : resourceTemplates) { this.resourceTemplates.add(resourceTemplate); @@ -539,7 +639,7 @@ public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplate * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public AsyncSpecification prompts(Map prompts) { + public AsyncSpecification prompts(Map prompts) { Assert.notNull(prompts, "Prompts map must not be null"); this.prompts.putAll(prompts); return this; @@ -553,7 +653,7 @@ public AsyncSpecification prompts(Map prompts) { + public AsyncSpecification prompts(List prompts) { Assert.notNull(prompts, "Prompts list must not be null"); for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); @@ -577,7 +677,7 @@ public AsyncSpecification prompts(List prompts(McpServerFeatures.AsyncPromptSpecification... prompts) { Assert.notNull(prompts, "Prompts list must not be null"); for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); @@ -592,7 +692,7 @@ public AsyncSpecification prompts(McpServerFeatures.AsyncPromptSpecification... * @return This builder instance for method chaining * @throws IllegalArgumentException if completions is null */ - public AsyncSpecification completions(List completions) { + public AsyncSpecification completions(List completions) { Assert.notNull(completions, "Completions list must not be null"); for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { this.completions.put(completion.referenceKey(), completion); @@ -607,7 +707,7 @@ public AsyncSpecification completions(List completions(McpServerFeatures.AsyncCompletionSpecification... completions) { Assert.notNull(completions, "Completions list must not be null"); for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { this.completions.put(completion.referenceKey(), completion); @@ -625,7 +725,7 @@ public AsyncSpecification completions(McpServerFeatures.AsyncCompletionSpecifica * @return This builder instance for method chaining * @throws IllegalArgumentException if consumer is null */ - public AsyncSpecification rootsChangeHandler( + public AsyncSpecification rootsChangeHandler( BiFunction, Mono> handler) { Assert.notNull(handler, "Consumer must not be null"); this.rootsChangeHandlers.add(handler); @@ -641,7 +741,7 @@ public AsyncSpecification rootsChangeHandler( * @throws IllegalArgumentException if consumers is null * @see #rootsChangeHandler(BiFunction) */ - public AsyncSpecification rootsChangeHandlers( + public AsyncSpecification rootsChangeHandlers( List, Mono>> handlers) { Assert.notNull(handlers, "Handlers list must not be null"); this.rootsChangeHandlers.addAll(handlers); @@ -657,7 +757,7 @@ public AsyncSpecification rootsChangeHandlers( * @throws IllegalArgumentException if consumers is null * @see #rootsChangeHandlers(List) */ - public AsyncSpecification rootsChangeHandlers( + public AsyncSpecification rootsChangeHandlers( @SuppressWarnings("unchecked") BiFunction, Mono>... handlers) { Assert.notNull(handlers, "Handlers list must not be null"); return this.rootsChangeHandlers(Arrays.asList(handlers)); @@ -669,7 +769,7 @@ public AsyncSpecification rootsChangeHandlers( * @return This builder instance for method chaining. * @throws IllegalArgumentException if objectMapper is null */ - public AsyncSpecification objectMapper(ObjectMapper objectMapper) { + public AsyncSpecification objectMapper(ObjectMapper objectMapper) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); this.objectMapper = objectMapper; return this; @@ -683,26 +783,76 @@ public AsyncSpecification objectMapper(ObjectMapper objectMapper) { * @return This builder instance for method chaining * @throws IllegalArgumentException if jsonSchemaValidator is null */ - public AsyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + public AsyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); this.jsonSchemaValidator = jsonSchemaValidator; return this; } + } + + class SingleSessionSyncSpecification extends SyncSpecification { + + private final McpServerTransportProvider transportProvider; + + private SingleSessionSyncSpecification(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + /** - * Builds an asynchronous MCP server that provides non-blocking operations. - * @return A new instance of {@link McpAsyncServer} configured with this builder's + * Builds a synchronous MCP server that provides blocking operations. + * @return A new instance of {@link McpSyncServer} configured with this builder's * settings. */ - public McpAsyncServer build() { - var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, - this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, - this.instructions); + @Override + public McpSyncServer build() { + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions); + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, + this.immediateExecution); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator : new DefaultJsonSchemaValidator(mapper); - return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, + + var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); + + return new McpSyncServer(asyncServer, this.immediateExecution); + } + + } + + class StreamableSyncSpecification extends SyncSpecification { + + private final McpStreamableServerTransportProvider transportProvider; + + private StreamableSyncSpecification(McpStreamableServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + + /** + * Builds a synchronous MCP server that provides blocking operations. + * @return A new instance of {@link McpSyncServer} configured with this builder's + * settings. + */ + @Override + public McpSyncServer build() { + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions); + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, + this.immediateExecution); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(mapper); + + var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, this.uriTemplateManagerFactory, jsonSchemaValidator); + + return new McpSyncServer(asyncServer, this.immediateExecution); } } @@ -710,22 +860,17 @@ public McpAsyncServer build() { /** * Synchronous server specification. */ - class SyncSpecification { - - private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", - "1.0.0"); - - private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + abstract class SyncSpecification> { - private final McpServerTransportProvider transportProvider; + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); - private ObjectMapper objectMapper; + ObjectMapper objectMapper; - private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; - private McpSchema.ServerCapabilities serverCapabilities; + McpSchema.ServerCapabilities serverCapabilities; - private String instructions; + String instructions; /** * The Model Context Protocol (MCP) allows servers to expose tools that can be @@ -734,7 +879,7 @@ class SyncSpecification { * Each tool is uniquely identified by a name and includes metadata describing its * schema. */ - private final List tools = new ArrayList<>(); + final List tools = new ArrayList<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -743,11 +888,11 @@ class SyncSpecification { * application-specific information. Each resource is uniquely identified by a * URI. */ - private final Map resources = new HashMap<>(); + final Map resources = new HashMap<>(); - private final List resourceTemplates = new ArrayList<>(); + final List resourceTemplates = new ArrayList<>(); - private JsonSchemaValidator jsonSchemaValidator; + JsonSchemaValidator jsonSchemaValidator; /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -756,20 +901,17 @@ class SyncSpecification { * discover available prompts, retrieve their contents, and provide arguments to * customize them. */ - private final Map prompts = new HashMap<>(); + final Map prompts = new HashMap<>(); - private final Map completions = new HashMap<>(); + final Map completions = new HashMap<>(); - private final List>> rootsChangeHandlers = new ArrayList<>(); + final List>> rootsChangeHandlers = new ArrayList<>(); - private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout - private boolean immediateExecution = false; + boolean immediateExecution = false; - private SyncSpecification(McpServerTransportProvider transportProvider) { - Assert.notNull(transportProvider, "Transport provider must not be null"); - this.transportProvider = transportProvider; - } + public abstract McpSyncServer build(); /** * Sets the URI template manager factory to use for creating URI templates. This @@ -778,7 +920,7 @@ private SyncSpecification(McpServerTransportProvider transportProvider) { * @return This builder instance for method chaining * @throws IllegalArgumentException if uriTemplateManagerFactory is null */ - public SyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + public SyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); this.uriTemplateManagerFactory = uriTemplateManagerFactory; return this; @@ -790,10 +932,10 @@ public SyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory * resource access, and prompt operations. * @param requestTimeout The duration to wait before timing out requests. Must not * be null. - * @return This builder instance for method chaining + * @return this builder instance for method chaining * @throws IllegalArgumentException if requestTimeout is null */ - public SyncSpecification requestTimeout(Duration requestTimeout) { + public SyncSpecification requestTimeout(Duration requestTimeout) { Assert.notNull(requestTimeout, "Request timeout must not be null"); this.requestTimeout = requestTimeout; return this; @@ -808,7 +950,7 @@ public SyncSpecification requestTimeout(Duration requestTimeout) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverInfo is null */ - public SyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + public SyncSpecification serverInfo(McpSchema.Implementation serverInfo) { Assert.notNull(serverInfo, "Server info must not be null"); this.serverInfo = serverInfo; return this; @@ -824,7 +966,7 @@ public SyncSpecification serverInfo(McpSchema.Implementation serverInfo) { * @throws IllegalArgumentException if name or version is null or empty * @see #serverInfo(McpSchema.Implementation) */ - public SyncSpecification serverInfo(String name, String version) { + public SyncSpecification serverInfo(String name, String version) { Assert.hasText(name, "Name must not be null or empty"); Assert.hasText(version, "Version must not be null or empty"); this.serverInfo = new McpSchema.Implementation(name, version); @@ -838,7 +980,7 @@ public SyncSpecification serverInfo(String name, String version) { * @param instructions The instructions text. Can be null or empty. * @return This builder instance for method chaining */ - public SyncSpecification instructions(String instructions) { + public SyncSpecification instructions(String instructions) { this.instructions = instructions; return this; } @@ -857,7 +999,7 @@ public SyncSpecification instructions(String instructions) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverCapabilities is null */ - public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { Assert.notNull(serverCapabilities, "Server capabilities must not be null"); this.serverCapabilities = serverCapabilities; return this; @@ -887,7 +1029,7 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabil * calls that require a request object. */ @Deprecated - public SyncSpecification tool(McpSchema.Tool tool, + public SyncSpecification tool(McpSchema.Tool tool, BiFunction, McpSchema.CallToolResult> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); @@ -911,7 +1053,7 @@ public SyncSpecification tool(McpSchema.Tool tool, * @return This builder instance for method chaining * @throws IllegalArgumentException if tool or handler is null */ - public SyncSpecification toolCall(McpSchema.Tool tool, + public SyncSpecification toolCall(McpSchema.Tool tool, BiFunction handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); @@ -932,7 +1074,7 @@ public SyncSpecification toolCall(McpSchema.Tool tool, * @throws IllegalArgumentException if toolSpecifications is null * @see #tools(McpServerFeatures.SyncToolSpecification...) */ - public SyncSpecification tools(List toolSpecifications) { + public SyncSpecification tools(List toolSpecifications) { Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (var tool : toolSpecifications) { @@ -961,7 +1103,7 @@ public SyncSpecification tools(List too * @throws IllegalArgumentException if toolSpecifications is null * @see #tools(List) */ - public SyncSpecification tools(McpServerFeatures.SyncToolSpecification... toolSpecifications) { + public SyncSpecification tools(McpServerFeatures.SyncToolSpecification... toolSpecifications) { Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (McpServerFeatures.SyncToolSpecification tool : toolSpecifications) { @@ -987,7 +1129,7 @@ private void assertNoDuplicateTool(String toolName) { * @throws IllegalArgumentException if resourceSpecifications is null * @see #resources(McpServerFeatures.SyncResourceSpecification...) */ - public SyncSpecification resources( + public SyncSpecification resources( Map resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); this.resources.putAll(resourceSpecifications); @@ -1003,7 +1145,8 @@ public SyncSpecification resources( * @throws IllegalArgumentException if resourceSpecifications is null * @see #resources(McpServerFeatures.SyncResourceSpecification...) */ - public SyncSpecification resources(List resourceSpecifications) { + public SyncSpecification resources( + List resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); @@ -1028,7 +1171,7 @@ public SyncSpecification resources(List resources(McpServerFeatures.SyncResourceSpecification... resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); @@ -1053,7 +1196,7 @@ public SyncSpecification resources(McpServerFeatures.SyncResourceSpecification.. * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(ResourceTemplate...) */ - public SyncSpecification resourceTemplates(List resourceTemplates) { + public SyncSpecification resourceTemplates(List resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); this.resourceTemplates.addAll(resourceTemplates); return this; @@ -1067,7 +1210,7 @@ public SyncSpecification resourceTemplates(List resourceTempla * @throws IllegalArgumentException if resourceTemplates is null * @see #resourceTemplates(List) */ - public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); for (ResourceTemplate resourceTemplate : resourceTemplates) { this.resourceTemplates.add(resourceTemplate); @@ -1093,7 +1236,7 @@ public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public SyncSpecification prompts(Map prompts) { + public SyncSpecification prompts(Map prompts) { Assert.notNull(prompts, "Prompts map must not be null"); this.prompts.putAll(prompts); return this; @@ -1107,7 +1250,7 @@ public SyncSpecification prompts(Map prompts) { + public SyncSpecification prompts(List prompts) { Assert.notNull(prompts, "Prompts list must not be null"); for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); @@ -1131,7 +1274,7 @@ public SyncSpecification prompts(List * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... prompts) { + public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... prompts) { Assert.notNull(prompts, "Prompts list must not be null"); for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); @@ -1147,7 +1290,7 @@ public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... pr * @throws IllegalArgumentException if completions is null * @see #completions(McpServerFeatures.SyncCompletionSpecification...) */ - public SyncSpecification completions(List completions) { + public SyncSpecification completions(List completions) { Assert.notNull(completions, "Completions list must not be null"); for (McpServerFeatures.SyncCompletionSpecification completion : completions) { this.completions.put(completion.referenceKey(), completion); @@ -1162,7 +1305,7 @@ public SyncSpecification completions(List completions(McpServerFeatures.SyncCompletionSpecification... completions) { Assert.notNull(completions, "Completions list must not be null"); for (McpServerFeatures.SyncCompletionSpecification completion : completions) { this.completions.put(completion.referenceKey(), completion); @@ -1180,7 +1323,8 @@ public SyncSpecification completions(McpServerFeatures.SyncCompletionSpecificati * @return This builder instance for method chaining * @throws IllegalArgumentException if consumer is null */ - public SyncSpecification rootsChangeHandler(BiConsumer> handler) { + public SyncSpecification rootsChangeHandler( + BiConsumer> handler) { Assert.notNull(handler, "Consumer must not be null"); this.rootsChangeHandlers.add(handler); return this; @@ -1195,7 +1339,7 @@ public SyncSpecification rootsChangeHandler(BiConsumer rootsChangeHandlers( List>> handlers) { Assert.notNull(handlers, "Handlers list must not be null"); this.rootsChangeHandlers.addAll(handlers); @@ -1211,7 +1355,7 @@ public SyncSpecification rootsChangeHandlers( * @throws IllegalArgumentException if consumers is null * @see #rootsChangeHandlers(List) */ - public SyncSpecification rootsChangeHandlers( + public SyncSpecification rootsChangeHandlers( BiConsumer>... handlers) { Assert.notNull(handlers, "Handlers list must not be null"); return this.rootsChangeHandlers(List.of(handlers)); @@ -1223,13 +1367,13 @@ public SyncSpecification rootsChangeHandlers( * @return This builder instance for method chaining. * @throws IllegalArgumentException if objectMapper is null */ - public SyncSpecification objectMapper(ObjectMapper objectMapper) { + public SyncSpecification objectMapper(ObjectMapper objectMapper) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); this.objectMapper = objectMapper; return this; } - public SyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + public SyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); this.jsonSchemaValidator = jsonSchemaValidator; return this; @@ -1246,30 +1390,966 @@ public SyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValid * @return This builder instance for method chaining. * */ - public SyncSpecification immediateExecution(boolean immediateExecution) { + public SyncSpecification immediateExecution(boolean immediateExecution) { this.immediateExecution = immediateExecution; return this; } + } + + class StatelessAsyncSpecification { + + private final McpStatelessServerTransport transport; + + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + + ObjectMapper objectMapper; + + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + McpSchema.ServerCapabilities serverCapabilities; + + JsonSchemaValidator jsonSchemaValidator; + + String instructions; + /** - * Builds a synchronous MCP server that provides blocking operations. - * @return A new instance of {@link McpSyncServer} configured with this builder's - * settings. + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. */ - public McpSyncServer build() { - McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, - this.rootsChangeHandlers, this.instructions); - McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, - this.immediateExecution); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator - : new DefaultJsonSchemaValidator(mapper); + final List tools = new ArrayList<>(); - var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + final Map resources = new HashMap<>(); - return new McpSyncServer(asyncServer, this.immediateExecution); + final List resourceTemplates = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + final Map prompts = new HashMap<>(); + + final Map completions = new HashMap<>(); + + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + + public StatelessAsyncSpecification(McpStatelessServerTransport transport) { + this.transport = transport; + } + + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public StatelessAsyncSpecification uriTemplateManagerFactory( + McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public StatelessAsyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public StatelessAsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public StatelessAsyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server instructions that will be shared with clients during connection + * initialization. These instructions provide guidance to the client on how to + * interact with this server. + * @param instructions The instructions text. Can be null or empty. + * @return This builder instance for method chaining + */ + public StatelessAsyncSpecification instructions(String instructions) { + this.instructions = instructions; + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *
    + *
  • Tool execution + *
  • Resource access + *
  • Prompt handling + *
+ * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public StatelessAsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.AsyncToolSpecification} explicitly. + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param callHandler The function that implements the tool's logic. Must not be + * null. The function's first argument is an {@link McpAsyncServerExchange} upon + * which the server can interact with the connected client. The second argument is + * the {@link McpSchema.CallToolRequest} object containing the tool call + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public StatelessAsyncSpecification toolCall(McpSchema.Tool tool, + BiFunction> callHandler) { + + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools.add(new McpStatelessServerFeatures.AsyncToolSpecification(tool, callHandler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpStatelessServerFeatures.AsyncToolSpecification...) + */ + public StatelessAsyncSpecification tools( + List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

+ * Example usage:

{@code
+		 * .tools(
+		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(calculatorTool).callTool(calculatorHandler).build(),
+		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(weatherTool).callTool(weatherHandler).build(),
+		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(fileManagerTool).callTool(fileManagerHandler).build()
+		 * )
+		 * }
+ * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + */ + public StatelessAsyncSpecification tools( + McpStatelessServerFeatures.AsyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + return this; + } + + private void assertNoDuplicateTool(String toolName) { + if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { + throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); + } + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.AsyncResourceSpecification...) + */ + public StatelessAsyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.AsyncResourceSpecification...) + */ + public StatelessAsyncSpecification resources( + List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

+ * Example usage:

{@code
+		 * .resources(
+		 *     new McpServerFeatures.AsyncResourceSpecification(fileResource, fileHandler),
+		 *     new McpServerFeatures.AsyncResourceSpecification(dbResource, dbHandler),
+		 *     new McpServerFeatures.AsyncResourceSpecification(apiResource, apiHandler)
+		 * )
+		 * }
+ * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public StatelessAsyncSpecification resources( + McpStatelessServerFeatures.AsyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * + *

+ * Example usage:

{@code
+		 * .resourceTemplates(
+		 *     new ResourceTemplate("file://{path}", "Access files by path"),
+		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
+		 * )
+		 * }
+ * @param resourceTemplates List of resource templates. If null, clears existing + * templates. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(ResourceTemplate...) + */ + public StatelessAsyncSpecification resourceTemplates(List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + this.resourceTemplates.addAll(resourceTemplates); + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(List) + */ + public StatelessAsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (ResourceTemplate resourceTemplate : resourceTemplates) { + this.resourceTemplates.add(resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

+ * Example usage:

{@code
+		 * .prompts(Map.of("analysis", new McpServerFeatures.AsyncPromptSpecification(
+		 *     new Prompt("analysis", "Code analysis template"),
+		 *     request -> Mono.fromSupplier(() -> generateAnalysisPrompt(request))
+		 *         .map(GetPromptResult::new)
+		 * )));
+		 * }
+ * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessAsyncSpecification prompts( + Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpStatelessServerFeatures.AsyncPromptSpecification...) + */ + public StatelessAsyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

+ * Example usage:

{@code
+		 * .prompts(
+		 *     new McpServerFeatures.AsyncPromptSpecification(analysisPrompt, analysisHandler),
+		 *     new McpServerFeatures.AsyncPromptSpecification(summaryPrompt, summaryHandler),
+		 *     new McpServerFeatures.AsyncPromptSpecification(reviewPrompt, reviewHandler)
+		 * )
+		 * }
+ * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessAsyncSpecification prompts(McpStatelessServerFeatures.AsyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessAsyncSpecification completions( + List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessAsyncSpecification completions( + McpStatelessServerFeatures.AsyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Sets the object mapper to use for serializing and deserializing JSON messages. + * @param objectMapper the instance to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if objectMapper is null + */ + public StatelessAsyncSpecification objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the JSON schema validator to use for validating tool and resource schemas. + * This ensures that the server's tools and resources conform to the expected + * schema definitions. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public StatelessAsyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + public McpStatelessAsyncServer build() { + var features = new McpStatelessServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(mapper); + return new McpStatelessAsyncServer(this.transport, mapper, features, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); + } + + } + + class StatelessSyncSpecification { + + private final McpStatelessServerTransport transport; + + boolean immediateExecution = false; + + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + + ObjectMapper objectMapper; + + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + McpSchema.ServerCapabilities serverCapabilities; + + JsonSchemaValidator jsonSchemaValidator; + + String instructions; + + /** + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. + */ + final List tools = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + final Map resources = new HashMap<>(); + + final List resourceTemplates = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + final Map prompts = new HashMap<>(); + + final Map completions = new HashMap<>(); + + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + + public StatelessSyncSpecification(McpStatelessServerTransport transport) { + this.transport = transport; + } + + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public StatelessSyncSpecification uriTemplateManagerFactory( + McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public StatelessSyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public StatelessSyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public StatelessSyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server instructions that will be shared with clients during connection + * initialization. These instructions provide guidance to the client on how to + * interact with this server. + * @param instructions The instructions text. Can be null or empty. + * @return This builder instance for method chaining + */ + public StatelessSyncSpecification instructions(String instructions) { + this.instructions = instructions; + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *
    + *
  • Tool execution + *
  • Resource access + *
  • Prompt handling + *
+ * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public StatelessSyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.SyncToolSpecification} explicitly. + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param callHandler The function that implements the tool's logic. Must not be + * null. The function's first argument is an {@link McpSyncServerExchange} upon + * which the server can interact with the connected client. The second argument is + * the {@link McpSchema.CallToolRequest} object containing the tool call + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public StatelessSyncSpecification toolCall(McpSchema.Tool tool, + BiFunction callHandler) { + + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools.add(new McpStatelessServerFeatures.SyncToolSpecification(tool, callHandler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpStatelessServerFeatures.SyncToolSpecification...) + */ + public StatelessSyncSpecification tools( + List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

+ * Example usage:

{@code
+		 * .tools(
+		 *     McpServerFeatures.SyncToolSpecification.builder().tool(calculatorTool).callTool(calculatorHandler).build(),
+		 *     McpServerFeatures.SyncToolSpecification.builder().tool(weatherTool).callTool(weatherHandler).build(),
+		 *     McpServerFeatures.SyncToolSpecification.builder().tool(fileManagerTool).callTool(fileManagerHandler).build()
+		 * )
+		 * }
+ * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + */ + public StatelessSyncSpecification tools( + McpStatelessServerFeatures.SyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + return this; + } + + private void assertNoDuplicateTool(String toolName) { + if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { + throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); + } + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.SyncResourceSpecification...) + */ + public StatelessSyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.SyncResourceSpecification...) + */ + public StatelessSyncSpecification resources( + List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

+ * Example usage:

{@code
+		 * .resources(
+		 *     new McpServerFeatures.SyncResourceSpecification(fileResource, fileHandler),
+		 *     new McpServerFeatures.SyncResourceSpecification(dbResource, dbHandler),
+		 *     new McpServerFeatures.SyncResourceSpecification(apiResource, apiHandler)
+		 * )
+		 * }
+ * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public StatelessSyncSpecification resources( + McpStatelessServerFeatures.SyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * + *

+ * Example usage:

{@code
+		 * .resourceTemplates(
+		 *     new ResourceTemplate("file://{path}", "Access files by path"),
+		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
+		 * )
+		 * }
+ * @param resourceTemplates List of resource templates. If null, clears existing + * templates. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(ResourceTemplate...) + */ + public StatelessSyncSpecification resourceTemplates(List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + this.resourceTemplates.addAll(resourceTemplates); + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(List) + */ + public StatelessSyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (ResourceTemplate resourceTemplate : resourceTemplates) { + this.resourceTemplates.add(resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

+ * Example usage:

{@code
+		 * .prompts(Map.of("analysis", new McpServerFeatures.SyncPromptSpecification(
+		 *     new Prompt("analysis", "Code analysis template"),
+		 *     request -> Mono.fromSupplier(() -> generateAnalysisPrompt(request))
+		 *         .map(GetPromptResult::new)
+		 * )));
+		 * }
+ * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessSyncSpecification prompts( + Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpStatelessServerFeatures.SyncPromptSpecification...) + */ + public StatelessSyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

+ * Example usage:

{@code
+		 * .prompts(
+		 *     new McpServerFeatures.SyncPromptSpecification(analysisPrompt, analysisHandler),
+		 *     new McpServerFeatures.SyncPromptSpecification(summaryPrompt, summaryHandler),
+		 *     new McpServerFeatures.SyncPromptSpecification(reviewPrompt, reviewHandler)
+		 * )
+		 * }
+ * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessSyncSpecification prompts(McpStatelessServerFeatures.SyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessSyncSpecification completions( + List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessSyncSpecification completions( + McpStatelessServerFeatures.SyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Sets the object mapper to use for serializing and deserializing JSON messages. + * @param objectMapper the instance to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if objectMapper is null + */ + public StatelessSyncSpecification objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the JSON schema validator to use for validating tool and resource schemas. + * This ensures that the server's tools and resources conform to the expected + * schema definitions. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public StatelessSyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + /** + * Enable on "immediate execution" of the operations on the underlying + * {@link McpStatelessAsyncServer}. Defaults to false, which does blocking code + * offloading to prevent accidental blocking of the non-blocking transport. + *

+ * Do NOT set to true if the underlying transport is a non-blocking + * implementation. + * @param immediateExecution When true, do not offload work asynchronously. + * @return This builder instance for method chaining. + * + */ + public StatelessSyncSpecification immediateExecution(boolean immediateExecution) { + this.immediateExecution = immediateExecution; + return this; + } + + public McpStatelessSyncServer build() { + /* + * McpServerFeatures.Sync syncFeatures = new + * McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + * this.tools, this.resources, this.resourceTemplates, this.prompts, + * this.completions, this.rootsChangeHandlers, this.instructions); + * McpServerFeatures.Async asyncFeatures = + * McpServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); + * var mapper = this.objectMapper != null ? this.objectMapper : new + * ObjectMapper(); var jsonSchemaValidator = this.jsonSchemaValidator != null + * ? this.jsonSchemaValidator : new DefaultJsonSchemaValidator(mapper); + * + * var asyncServer = new McpAsyncServer(this.transportProvider, mapper, + * asyncFeatures, this.requestTimeout, this.uriTemplateManagerFactory, + * jsonSchemaValidator); + * + * return new McpSyncServer(asyncServer, this.immediateExecution); + */ + var syncFeatures = new McpStatelessServerFeatures.Sync(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); + var asyncFeatures = McpStatelessServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(mapper); + var asyncServer = new McpStatelessAsyncServer(this.transport, mapper, asyncFeatures, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); + return new McpStatelessSyncServer(asyncServer, this.immediateExecution); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java new file mode 100644 index 000000000..41e0e9588 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java @@ -0,0 +1,673 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.JsonSchemaValidator; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.BiFunction; + +/** + * A stateless MCP server implementation for use with Streamable HTTP transport types. It + * allows simple horizontal scalability since it does not maintain a session and does not + * require initialization. Each instance of the server can be reached with no prior + * knowledge and can serve the clients with the capabilities it supports. + * + * @author Dariusz Jędrzejczyk + */ +public class McpStatelessAsyncServer { + + private static final Logger logger = LoggerFactory.getLogger(McpStatelessAsyncServer.class); + + private final McpStatelessServerTransport mcpTransportProvider; + + private final ObjectMapper objectMapper; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private final String instructions; + + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + + private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); + + private List protocolVersions; + + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + + private final JsonSchemaValidator jsonSchemaValidator; + + McpStatelessAsyncServer(McpStatelessServerTransport mcpTransport, ObjectMapper objectMapper, + McpStatelessServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + this.mcpTransportProvider = mcpTransport; + this.objectMapper = objectMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.instructions = features.instructions(); + this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); + this.resources.putAll(features.resources()); + this.resourceTemplates.addAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + this.jsonSchemaValidator = jsonSchemaValidator; + + Map> requestHandlers = new HashMap<>(); + + // Initialize request handlers for standard MCP methods + + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, (ctx, params) -> Mono.just(Map.of())); + + requestHandlers.put(McpSchema.METHOD_INITIALIZE, asyncInitializeRequestHandler()); + + // Add tools API handlers if the tool capability is enabled + if (this.serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + } + + // Add resources API handlers if provided + if (this.serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + } + + // Add prompts API handlers if provider exists + if (this.serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + } + + // Add completion API handlers if the completion capability is enabled + if (this.serverCapabilities.completions() != null) { + requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); + } + + this.protocolVersions = new ArrayList<>(mcpTransport.protocolVersions()); + + McpStatelessServerHandler handler = new DefaultMcpStatelessServerHandler(requestHandlers, Map.of()); + mcpTransport.setMcpHandler(handler); + } + + // --------------------------------------- + // Lifecycle Management + // --------------------------------------- + private McpStatelessRequestHandler asyncInitializeRequestHandler() { + return (ctx, req) -> Mono.defer(() -> { + McpSchema.InitializeRequest initializeRequest = this.objectMapper.convertValue(req, + McpSchema.InitializeRequest.class); + + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // The server MUST respond with the highest protocol version it supports + // if + // it does not support the requested (e.g. Client) version. + String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { + // If the server supports the requested protocol version, it MUST + // respond + // with the same version. + serverProtocolVersion = initializeRequest.protocolVersion(); + } + else { + logger.warn( + "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, + this.serverInfo, this.instructions)); + }); + } + + /** + * Get the server capabilities that define the supported features and functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.serverCapabilities; + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.serverInfo; + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.mcpTransportProvider.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.mcpTransportProvider.close(); + } + + // --------------------------------------- + // Tool Management + // --------------------------------------- + + private static List withStructuredOutputHandling( + JsonSchemaValidator jsonSchemaValidator, List tools) { + + if (Utils.isEmpty(tools)) { + return tools; + } + + return tools.stream().map(tool -> withStructuredOutputHandling(jsonSchemaValidator, tool)).toList(); + } + + private static McpStatelessServerFeatures.AsyncToolSpecification withStructuredOutputHandling( + JsonSchemaValidator jsonSchemaValidator, + McpStatelessServerFeatures.AsyncToolSpecification toolSpecification) { + + if (toolSpecification.callHandler() instanceof StructuredOutputCallToolHandler) { + // If the tool is already wrapped, return it as is + return toolSpecification; + } + + if (toolSpecification.tool().outputSchema() == null) { + // If the tool does not have an output schema, return it as is + return toolSpecification; + } + + return new McpStatelessServerFeatures.AsyncToolSpecification(toolSpecification.tool(), + new StructuredOutputCallToolHandler(jsonSchemaValidator, toolSpecification.tool().outputSchema(), + toolSpecification.callHandler())); + } + + private static class StructuredOutputCallToolHandler + implements BiFunction> { + + private final BiFunction> delegateHandler; + + private final JsonSchemaValidator jsonSchemaValidator; + + private final Map outputSchema; + + public StructuredOutputCallToolHandler(JsonSchemaValidator jsonSchemaValidator, + Map outputSchema, + BiFunction> delegateHandler) { + + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + Assert.notNull(delegateHandler, "Delegate call tool result handler must not be null"); + + this.delegateHandler = delegateHandler; + this.outputSchema = outputSchema; + this.jsonSchemaValidator = jsonSchemaValidator; + } + + @Override + public Mono apply(McpTransportContext transportContext, McpSchema.CallToolRequest request) { + + return this.delegateHandler.apply(transportContext, request).map(result -> { + + if (outputSchema == null) { + if (result.structuredContent() != null) { + logger.warn( + "Tool call with no outputSchema is not expected to have a result with structured content, but got: {}", + result.structuredContent()); + } + // Pass through. No validation is required if no output schema is + // provided. + return result; + } + + // If an output schema is provided, servers MUST provide structured + // results that conform to this schema. + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#output-schema + if (result.structuredContent() == null) { + logger.warn( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + return new CallToolResult( + "Response missing structured content which is expected when calling tool with non-empty outputSchema", + true); + } + + // Validate the result against the output schema + var validation = this.jsonSchemaValidator.validate(outputSchema, result.structuredContent()); + + if (!validation.valid()) { + logger.warn("Tool call result validation failed: {}", validation.errorMessage()); + return new CallToolResult(validation.errorMessage(), true); + } + + if (Utils.isEmpty(result.content())) { + // For backwards compatibility, a tool that returns structured + // content SHOULD also return functionally equivalent unstructured + // content. (For example, serialized JSON can be returned in a + // TextContent block.) + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content + + return new CallToolResult(List.of(new McpSchema.TextContent(validation.jsonStructuredOutput())), + result.isError(), result.structuredContent()); + } + + return result; + }); + } + + } + + /** + * Add a new tool specification at runtime. + * @param toolSpecification The tool specification to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addTool(McpStatelessServerFeatures.AsyncToolSpecification toolSpecification) { + if (toolSpecification == null) { + return Mono.error(new McpError("Tool specification must not be null")); + } + if (toolSpecification.tool() == null) { + return Mono.error(new McpError("Tool must not be null")); + } + if (toolSpecification.callHandler() == null) { + return Mono.error(new McpError("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + var wrappedToolSpecification = withStructuredOutputHandling(this.jsonSchemaValidator, toolSpecification); + + return Mono.defer(() -> { + // Check for duplicate tool names + if (this.tools.stream().anyMatch(th -> th.tool().name().equals(wrappedToolSpecification.tool().name()))) { + return Mono.error( + new McpError("Tool with name '" + wrappedToolSpecification.tool().name() + "' already exists")); + } + + this.tools.add(wrappedToolSpecification); + logger.debug("Added tool handler: {}", wrappedToolSpecification.tool().name()); + + return Mono.empty(); + }); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeTool(String toolName) { + if (toolName == null) { + return Mono.error(new McpError("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + boolean removed = this.tools + .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); + if (removed) { + logger.debug("Removed tool handler: {}", toolName); + return Mono.empty(); + } + return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + }); + } + + private McpStatelessRequestHandler toolsListRequestHandler() { + return (ctx, params) -> { + List tools = this.tools.stream() + .map(McpStatelessServerFeatures.AsyncToolSpecification::tool) + .toList(); + return Mono.just(new McpSchema.ListToolsResult(tools, null)); + }; + } + + private McpStatelessRequestHandler toolsCallRequestHandler() { + return (ctx, params) -> { + McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + Optional toolSpecification = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolSpecification.isEmpty()) { + return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + } + + return toolSpecification.map(tool -> tool.callHandler().apply(ctx, callToolRequest)) + .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + }; + } + + // --------------------------------------- + // Resource Management + // --------------------------------------- + + /** + * Add a new resource handler at runtime. + * @param resourceSpecification The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResource(McpStatelessServerFeatures.AsyncResourceSpecification resourceSpecification) { + if (resourceSpecification == null || resourceSpecification.resource() == null) { + return Mono.error(new McpError("Resource must not be null")); + } + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { + return Mono.error(new McpError( + "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); + } + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); + return Mono.empty(); + }); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResource(String resourceUri) { + if (resourceUri == null) { + return Mono.error(new McpError("Resource URI must not be null")); + } + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + McpStatelessServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); + if (removed != null) { + logger.debug("Removed resource handler: {}", resourceUri); + return Mono.empty(); + } + return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); + }); + } + + private McpStatelessRequestHandler resourcesListRequestHandler() { + return (ctx, params) -> { + var resourceList = this.resources.values() + .stream() + .map(McpStatelessServerFeatures.AsyncResourceSpecification::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + }; + } + + private McpStatelessRequestHandler resourceTemplateListRequestHandler() { + return (ctx, params) -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); + + } + + private List getResourceTemplates() { + var list = new ArrayList<>(this.resourceTemplates); + List resourceTemplates = this.resources.keySet() + .stream() + .filter(uri -> uri.contains("{")) + .map(uri -> { + var resource = this.resources.get(uri).resource(); + var template = new ResourceTemplate(resource.uri(), resource.name(), resource.title(), + resource.description(), resource.mimeType(), resource.annotations()); + return template; + }) + .toList(); + + list.addAll(resourceTemplates); + + return list; + } + + private McpStatelessRequestHandler resourcesReadRequestHandler() { + return (ctx, params) -> { + McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + var resourceUri = resourceRequest.uri(); + + McpStatelessServerFeatures.AsyncResourceSpecification specification = this.resources.values() + .stream() + .filter(resourceSpecification -> this.uriTemplateManagerFactory + .create(resourceSpecification.resource().uri()) + .matches(resourceUri)) + .findFirst() + .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); + + return specification.readHandler().apply(ctx, resourceRequest); + }; + } + + // --------------------------------------- + // Prompt Management + // --------------------------------------- + + /** + * Add a new prompt handler at runtime. + * @param promptSpecification The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addPrompt(McpStatelessServerFeatures.AsyncPromptSpecification promptSpecification) { + if (promptSpecification == null) { + return Mono.error(new McpError("Prompt specification must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpStatelessServerFeatures.AsyncPromptSpecification specification = this.prompts + .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); + if (specification != null) { + return Mono.error( + new McpError("Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); + } + + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); + + return Mono.empty(); + }); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removePrompt(String promptName) { + if (promptName == null) { + return Mono.error(new McpError("Prompt name must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpStatelessServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); + + if (removed != null) { + logger.debug("Removed prompt handler: {}", promptName); + return Mono.empty(); + } + return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); + }); + } + + private McpStatelessRequestHandler promptsListRequestHandler() { + return (ctx, params) -> { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + // new TypeReference() { + // }); + + var promptList = this.prompts.values() + .stream() + .map(McpStatelessServerFeatures.AsyncPromptSpecification::prompt) + .toList(); + + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + }; + } + + private McpStatelessRequestHandler promptsGetRequestHandler() { + return (ctx, params) -> { + McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + // Implement prompt retrieval logic here + McpStatelessServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + if (specification == null) { + return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + } + + return specification.promptHandler().apply(ctx, promptRequest); + }; + } + + private McpStatelessRequestHandler completionCompleteRequestHandler() { + return (ctx, params) -> { + McpSchema.CompleteRequest request = parseCompletionParams(params); + + if (request.ref() == null) { + return Mono.error(new McpError("ref must not be null")); + } + + if (request.ref().type() == null) { + return Mono.error(new McpError("type must not be null")); + } + + String type = request.ref().type(); + + String argumentName = request.argument().name(); + + // check if the referenced resource exists + if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { + McpStatelessServerFeatures.AsyncPromptSpecification promptSpec = this.prompts + .get(promptReference.name()); + if (promptSpec == null) { + return Mono.error(new McpError("Prompt not found: " + promptReference.name())); + } + if (promptSpec.prompt().arguments().stream().noneMatch(arg -> arg.name().equals(argumentName))) { + + return Mono.error(new McpError("Argument not found: " + argumentName)); + } + } + + if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + McpStatelessServerFeatures.AsyncResourceSpecification resourceSpec = this.resources + .get(resourceReference.uri()); + if (resourceSpec == null) { + return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); + } + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + return Mono.error(new McpError("Argument not found: " + argumentName)); + } + + } + + McpStatelessServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); + + if (specification == null) { + return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); + } + + return specification.completionHandler().apply(ctx, request); + }; + } + + /** + * Parses the raw JSON-RPC request parameters into a {@link McpSchema.CompleteRequest} + * object. + *

+ * This method manually extracts the `ref` and `argument` fields from the input map, + * determines the correct reference type (either prompt or resource), and constructs a + * fully-typed {@code CompleteRequest} instance. + * @param object the raw request parameters, expected to be a Map containing "ref" and + * "argument" entries. + * @return a {@link McpSchema.CompleteRequest} representing the structured completion + * request. + * @throws IllegalArgumentException if the "ref" type is not recognized. + */ + @SuppressWarnings("unchecked") + private McpSchema.CompleteRequest parseCompletionParams(Object object) { + Map params = (Map) object; + Map refMap = (Map) params.get("ref"); + Map argMap = (Map) params.get("argument"); + + String refType = (String) refMap.get("type"); + + McpSchema.CompleteReference ref = switch (refType) { + case "ref/prompt" -> new McpSchema.PromptReference(refType, (String) refMap.get("name"), + refMap.get("title") != null ? (String) refMap.get("title") : null); + case "ref/resource" -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); + default -> throw new IllegalArgumentException("Invalid ref type: " + refType); + }; + + String argName = (String) argMap.get("name"); + String argValue = (String) argMap.get("value"); + McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument(argName, + argValue); + + return new McpSchema.CompleteRequest(ref, argument); + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java new file mode 100644 index 000000000..6db79a62c --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessNotificationHandler.java @@ -0,0 +1,24 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import reactor.core.publisher.Mono; + +/** + * Handler for MCP notifications in a stateless server. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpStatelessNotificationHandler { + + /** + * Handle to notification and complete once done. + * @param transportContext {@link McpTransportContext} associated with the transport + * @param params the payload of the MCP notification + * @return Mono which completes once the processing is done + */ + Mono handle(McpTransportContext transportContext, Object params); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java new file mode 100644 index 000000000..e5c9e7c09 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessRequestHandler.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import reactor.core.publisher.Mono; + +/** + * Handler for MCP requests in a stateless server. + * + * @param type of the MCP response + * @author Dariusz Jędrzejczyk + */ +public interface McpStatelessRequestHandler { + + /** + * Handle the request and complete with a result. + * @param transportContext {@link McpTransportContext} associated with the transport + * @param params the payload of the MCP request + * @return Mono which completes with the response object + */ + Mono handle(McpTransportContext transportContext, Object params); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java new file mode 100644 index 000000000..8be59a779 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java @@ -0,0 +1,482 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +/** + * MCP stateless server features specification that a particular server can choose to + * support. + * + * @author Dariusz Jędrzejczyk + * @author Christian Tzolov + */ +public class McpStatelessServerFeatures { + + /** + * Asynchronous server features specification. + * + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, List resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + /** + * Create an instance and validate the arguments. + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, List resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + Assert.notNull(serverInfo, "Server info must not be null"); + + this.serverInfo = serverInfo; + this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities + : new McpSchema.ServerCapabilities(null, // completions + null, // experimental + new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable + // logging + // by + // default + !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, + !Utils.isEmpty(resources) + ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, + !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + + this.tools = (tools != null) ? tools : List.of(); + this.resources = (resources != null) ? resources : Map.of(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : List.of(); + this.prompts = (prompts != null) ? prompts : Map.of(); + this.completions = (completions != null) ? completions : Map.of(); + this.instructions = instructions; + } + + /** + * Convert a synchronous specification into an asynchronous one and provide + * blocking code offloading to prevent accidental blocking of the non-blocking + * transport. + * @param syncSpec a potentially blocking, synchronous specification. + * @param immediateExecution when true, do not offload. Do NOT set to true when + * using a non-blocking transport. + * @return a specification which is protected from blocking calls specified by the + * user. + */ + static Async fromSync(Sync syncSpec, boolean immediateExecution) { + List tools = new ArrayList<>(); + for (var tool : syncSpec.tools()) { + tools.add(AsyncToolSpecification.fromSync(tool, immediateExecution)); + } + + Map resources = new HashMap<>(); + syncSpec.resources().forEach((key, resource) -> { + resources.put(key, AsyncResourceSpecification.fromSync(resource, immediateExecution)); + }); + + Map prompts = new HashMap<>(); + syncSpec.prompts().forEach((key, prompt) -> { + prompts.put(key, AsyncPromptSpecification.fromSync(prompt, immediateExecution)); + }); + + Map completions = new HashMap<>(); + syncSpec.completions().forEach((key, completion) -> { + completions.put(key, AsyncCompletionSpecification.fromSync(completion, immediateExecution)); + }); + + return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, + syncSpec.resourceTemplates(), prompts, completions, syncSpec.instructions()); + } + } + + /** + * Synchronous server features specification. + * + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + List resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + /** + * Create an instance and validate the arguments. + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + List resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + Assert.notNull(serverInfo, "Server info must not be null"); + + this.serverInfo = serverInfo; + this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities + : new McpSchema.ServerCapabilities(null, // completions + null, // experimental + new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable + // logging + // by + // default + !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, + !Utils.isEmpty(resources) + ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, + !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + + this.tools = (tools != null) ? tools : new ArrayList<>(); + this.resources = (resources != null) ? resources : new HashMap<>(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : new ArrayList<>(); + this.prompts = (prompts != null) ? prompts : new HashMap<>(); + this.completions = (completions != null) ? completions : new HashMap<>(); + this.instructions = instructions; + } + + } + + /** + * Specification of a tool with its asynchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. Each tool + * represents a specific capability. + * + * @param tool The tool definition including name, description, and parameter schema + * @param callHandler The function that implements the tool's logic, receiving a + * {@link CallToolRequest} and returning the result. + */ + public record AsyncToolSpecification(McpSchema.Tool tool, + BiFunction> callHandler) { + + static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec) { + return fromSync(syncToolSpec, false); + } + + static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec, boolean immediate) { + + // FIXME: This is temporary, proper validation should be implemented + if (syncToolSpec == null) { + return null; + } + + BiFunction> callHandler = (ctx, + req) -> { + var toolResult = Mono.fromCallable(() -> syncToolSpec.callHandler().apply(ctx, req)); + return immediate ? toolResult : toolResult.subscribeOn(Schedulers.boundedElastic()); + }; + + return new AsyncToolSpecification(syncToolSpec.tool(), callHandler); + } + + /** + * Builder for creating AsyncToolSpecification instances. + */ + public static class Builder { + + private McpSchema.Tool tool; + + private BiFunction> callHandler; + + /** + * Sets the tool definition. + * @param tool The tool definition including name, description, and parameter + * schema + * @return this builder instance + */ + public Builder tool(McpSchema.Tool tool) { + this.tool = tool; + return this; + } + + /** + * Sets the call tool handler function. + * @param callHandler The function that implements the tool's logic + * @return this builder instance + */ + public Builder callHandler( + BiFunction> callHandler) { + this.callHandler = callHandler; + return this; + } + + /** + * Builds the AsyncToolSpecification instance. + * @return a new AsyncToolSpecification instance + * @throws IllegalArgumentException if required fields are not set + */ + public AsyncToolSpecification build() { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Call handler function must not be null"); + + return new AsyncToolSpecification(tool, callHandler); + } + + } + + /** + * Creates a new builder instance. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + } + + /** + * Specification of a resource with its asynchronous handler function. Resources + * provide context to AI models by exposing data such as: + *

    + *
  • File contents + *
  • Database records + *
  • API responses + *
  • System information + *
  • Application state + *
+ * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests. The function's + * argument is a {@link McpSchema.ReadResourceRequest}. + */ + public record AsyncResourceSpecification(McpSchema.Resource resource, + BiFunction> readHandler) { + + static AsyncResourceSpecification fromSync(SyncResourceSpecification resource, boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceSpecification(resource.resource(), (ctx, req) -> { + var resourceResult = Mono.fromCallable(() -> resource.readHandler().apply(ctx, req)); + return immediateExecution ? resourceResult : resourceResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a prompt template with its asynchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
    + *
  • Consistent message formatting + *
  • Parameter substitution + *
  • Context injection + *
  • Response formatting + *
  • Instruction templating + *
+ * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates. The function's argument is a + * {@link McpSchema.GetPromptRequest}. + */ + public record AsyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction> promptHandler) { + + static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt, boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (prompt == null) { + return null; + } + return new AsyncPromptSpecification(prompt.prompt(), (ctx, req) -> { + var promptResult = Mono.fromCallable(() -> prompt.promptHandler().apply(ctx, req)); + return immediateExecution ? promptResult : promptResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a completion handler function with asynchronous execution support. + * Completions generate AI model outputs based on prompt or resource references and + * user-provided arguments. This abstraction enables: + *
    + *
  • Customizable response generation logic + *
  • Parameter-driven template expansion + *
  • Dynamic interaction with connected clients + *
+ * + * @param referenceKey The unique key representing the completion reference. + * @param completionHandler The asynchronous function that processes completion + * requests and returns results. The function's argument is a + * {@link McpSchema.CompleteRequest}. + */ + public record AsyncCompletionSpecification(McpSchema.CompleteReference referenceKey, + BiFunction> completionHandler) { + + /** + * Converts a synchronous {@link SyncCompletionSpecification} into an + * {@link AsyncCompletionSpecification} by wrapping the handler in a bounded + * elastic scheduler for safe non-blocking execution. + * @param completion the synchronous completion specification + * @return an asynchronous wrapper of the provided sync specification, or + * {@code null} if input is null + */ + static AsyncCompletionSpecification fromSync(SyncCompletionSpecification completion, + boolean immediateExecution) { + if (completion == null) { + return null; + } + return new AsyncCompletionSpecification(completion.referenceKey(), (ctx, req) -> { + var completionResult = Mono.fromCallable(() -> completion.completionHandler().apply(ctx, req)); + return immediateExecution ? completionResult + : completionResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a tool with its synchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. + * + * @param tool The tool definition including name, description, and parameter schema + * @param callHandler The function that implements the tool's logic, receiving a + * {@link CallToolRequest} and returning results. + */ + public record SyncToolSpecification(McpSchema.Tool tool, + BiFunction callHandler) { + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating SyncToolSpecification instances. + */ + public static class Builder { + + private McpSchema.Tool tool; + + private BiFunction callHandler; + + /** + * Sets the tool definition. + * @param tool The tool definition including name, description, and parameter + * schema + * @return this builder instance + */ + public Builder tool(McpSchema.Tool tool) { + this.tool = tool; + return this; + } + + /** + * Sets the call tool handler function. + * @param callHandler The function that implements the tool's logic + * @return this builder instance + */ + public Builder callHandler( + BiFunction callHandler) { + this.callHandler = callHandler; + return this; + } + + /** + * Builds the SyncToolSpecification instance. + * @return a new SyncToolSpecification instance + * @throws IllegalArgumentException if required fields are not set + */ + public SyncToolSpecification build() { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "CallTool function must not be null"); + + return new SyncToolSpecification(tool, callHandler); + } + + } + } + + /** + * Specification of a resource with its synchronous handler function. Resources + * provide context to AI models by exposing data such as: + *
    + *
  • File contents + *
  • Database records + *
  • API responses + *
  • System information + *
  • Application state + *
+ * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests. The function's + * argument is a {@link McpSchema.ReadResourceRequest}. + */ + public record SyncResourceSpecification(McpSchema.Resource resource, + BiFunction readHandler) { + } + + /** + * Specification of a prompt template with its synchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
    + *
  • Consistent message formatting + *
  • Parameter substitution + *
  • Context injection + *
  • Response formatting + *
  • Instruction templating + *
+ * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates. The function's argument is a + * {@link McpSchema.GetPromptRequest}. + */ + public record SyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction promptHandler) { + } + + /** + * Specification of a completion handler function with synchronous execution support. + * + * @param referenceKey The unique key representing the completion reference. + * @param completionHandler The synchronous function that processes completion + * requests and returns results. The argument is a {@link McpSchema.CompleteRequest}. + */ + public record SyncCompletionSpecification(McpSchema.CompleteReference referenceKey, + BiFunction completionHandler) { + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java new file mode 100644 index 000000000..7c4e23cfc --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerHandler.java @@ -0,0 +1,37 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpSchema; +import reactor.core.publisher.Mono; + +/** + * Handler for MCP requests and notifications in a Stateless Streamable HTTP Server + * context. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpStatelessServerHandler { + + /** + * Handle the request using user-provided feature implementations. + * @param transportContext {@link McpTransportContext} carrying transport layer + * metadata + * @param request the request JSON object + * @return Mono containing the JSON response + */ + Mono handleRequest(McpTransportContext transportContext, + McpSchema.JSONRPCRequest request); + + /** + * Handle the notification. + * @param transportContext {@link McpTransportContext} carrying transport layer + * metadata + * @param notification the notification JSON object + * @return Mono that completes once handling is finished + */ + Mono handleNotification(McpTransportContext transportContext, McpSchema.JSONRPCNotification notification); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java new file mode 100644 index 000000000..0151a754b --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java @@ -0,0 +1,132 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +import java.util.List; + +/** + * A stateless MCP server implementation for use with Streamable HTTP transport types. It + * allows simple horizontal scalability since it does not maintain a session and does not + * require initialization. Each instance of the server can be reached with no prior + * knowledge and can serve the clients with the capabilities it supports. + * + * @author Dariusz Jędrzejczyk + */ +public class McpStatelessSyncServer { + + private static final Logger logger = LoggerFactory.getLogger(McpStatelessSyncServer.class); + + private final McpStatelessAsyncServer asyncServer; + + private final boolean immediateExecution; + + McpStatelessSyncServer(McpStatelessAsyncServer asyncServer, boolean immediateExecution) { + this.asyncServer = asyncServer; + this.immediateExecution = immediateExecution; + } + + /** + * Get the server capabilities that define the supported features and functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.asyncServer.getServerCapabilities(); + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.asyncServer.getServerInfo(); + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.asyncServer.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.asyncServer.close(); + } + + /** + * Add a new tool specification at runtime. + * @param toolSpecification The tool specification to add + */ + public void addTool(McpStatelessServerFeatures.SyncToolSpecification toolSpecification) { + this.asyncServer + .addTool(McpStatelessServerFeatures.AsyncToolSpecification.fromSync(toolSpecification, + this.immediateExecution)) + .block(); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + */ + public void removeTool(String toolName) { + this.asyncServer.removeTool(toolName).block(); + } + + /** + * Add a new resource handler at runtime. + * @param resourceSpecification The resource handler to add + */ + public void addResource(McpStatelessServerFeatures.SyncResourceSpecification resourceSpecification) { + this.asyncServer + .addResource(McpStatelessServerFeatures.AsyncResourceSpecification.fromSync(resourceSpecification, + this.immediateExecution)) + .block(); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + */ + public void removeResource(String resourceUri) { + this.asyncServer.removeResource(resourceUri).block(); + } + + /** + * Add a new prompt handler at runtime. + * @param promptSpecification The prompt handler to add + */ + public void addPrompt(McpStatelessServerFeatures.SyncPromptSpecification promptSpecification) { + this.asyncServer + .addPrompt(McpStatelessServerFeatures.AsyncPromptSpecification.fromSync(promptSpecification, + this.immediateExecution)) + .block(); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + */ + public void removePrompt(String promptName) { + this.asyncServer.removePrompt(promptName).block(); + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.asyncServer.setProtocolVersions(protocolVersions); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java index 38f5128e4..5adda1a74 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServer.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.server; -import io.modelcontextprotocol.server.McpServerFeatures.AsyncToolSpecification; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.util.Assert; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index dad1e4c19..5f22df5e9 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -27,6 +27,14 @@ public McpSyncServerExchange(McpAsyncServerExchange exchange) { this.exchange = exchange; } + /** + * Provides the Session ID + * @return session ID + */ + public String sessionId() { + return this.exchange.sessionId(); + } + /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities @@ -43,6 +51,16 @@ public McpSchema.Implementation getClientInfo() { return this.exchange.getClientInfo(); } + /** + * Provides the {@link McpTransportContext} associated with the transport layer. For + * HTTP transports it can contain the metadata associated with the HTTP request that + * triggered the processing. + * @return the transport context object + */ + public McpTransportContext transportContext() { + return this.exchange.transportContext(); + } + /** * Create a new message using the sampling capabilities of the client. The Model * Context Protocol (MCP) provides a standardized way for servers to request LLM diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java new file mode 100644 index 000000000..1cd540f72 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContext.java @@ -0,0 +1,50 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.util.Collections; + +/** + * Context associated with the transport layer. It allows to add transport-level metadata + * for use further down the line. Specifically, it can be beneficial to extract HTTP + * request metadata for use in MCP feature implementations. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpTransportContext { + + /** + * Key for use in Reactor Context to transport the context to user land. + */ + String KEY = "MCP_TRANSPORT_CONTEXT"; + + /** + * An empty, unmodifiable context. + */ + @SuppressWarnings("unchecked") + McpTransportContext EMPTY = new DefaultMcpTransportContext(Collections.EMPTY_MAP); + + /** + * Extract a value from the context. + * @param key the key under the data is expected + * @return the associated value or {@code null} if missing. + */ + Object get(String key); + + /** + * Inserts a value for a given key. + * @param key a String representing the key + * @param value the value to store + */ + void put(String key, Object value); + + /** + * Copies the contents of the context to allow further modifications without affecting + * the initial object. + * @return a new instance with the underlying storage copied. + */ + McpTransportContext copy(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java new file mode 100644 index 000000000..97fcecf0d --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpTransportContextExtractor.java @@ -0,0 +1,28 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +/** + * The contract for extracting metadata from a generic transport request of type + * {@link T}. + * + * @param transport-specific representation of the request which allows extracting + * metadata for use in the MCP features implementations. + * @author Dariusz Jędrzejczyk + */ +public interface McpTransportContextExtractor { + + /** + * Given an empty context, provides the means to fill it with transport-specific + * metadata extracted from the request. + * @param request the generic representation for the request in the context of a + * specific transport implementation + * @param transportContext the mutable context which can be filled in with metadata + * @return the context filled in with metadata. It can be the same instance as + * provided or a new one. + */ + McpTransportContext extract(T request, McpTransportContext transportContext); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index afdbff472..d60e927f0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -1,11 +1,14 @@ /* * Copyright 2024 - 2024 the original author or authors. */ + package io.modelcontextprotocol.server.transport; import java.io.BufferedReader; import java.io.IOException; import java.io.PrintWriter; +import java.time.Duration; +import java.util.List; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -13,12 +16,17 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.DefaultMcpTransportContext; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; import jakarta.servlet.AsyncContext; import jakarta.servlet.ServletException; import jakarta.servlet.annotation.WebServlet; @@ -97,12 +105,20 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement /** Map of active client sessions, keyed by session ID */ private final Map sessions = new ConcurrentHashMap<>(); + private McpTransportContextExtractor contextExtractor; + /** Flag indicating if the transport is in the process of shutting down */ private final AtomicBoolean isClosing = new AtomicBoolean(false); /** Session factory for creating new sessions */ private McpServerSession.Factory sessionFactory; + /** + * Keep-alive scheduler for managing session pings. Activated if keepAliveInterval is + * set. Disabled by default. + */ + private KeepAliveScheduler keepAliveScheduler; + /** * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE * endpoint. @@ -110,7 +126,10 @@ public class HttpServletSseServerTransportProvider extends HttpServlet implement * serialization/deserialization * @param messageEndpoint The endpoint path where clients will send their messages * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ + @Deprecated public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); @@ -124,13 +143,74 @@ public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String m * @param baseUrl The base URL for the server transport * @param messageEndpoint The endpoint path where clients will send their messages * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. */ + @Deprecated public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null, null); + } + + /** + * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE + * endpoint. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param baseUrl The base URL for the server transport + * @param messageEndpoint The endpoint path where clients will send their messages + * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @param keepAliveInterval The interval for keep-alive pings, or null to disable + * keep-alive functionality + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. + */ + @Deprecated + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, keepAliveInterval, null); + } + + /** + * Creates a new HttpServletSseServerTransportProvider instance with a custom SSE + * endpoint. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param baseUrl The base URL for the server transport + * @param messageEndpoint The endpoint path where clients will send their messages + * @param sseEndpoint The endpoint path where clients will establish SSE connections + * @param keepAliveInterval The interval for keep-alive pings, or null to disable + * keep-alive functionality + * @param contextExtractor The extractor for transport context from the request. + * @deprecated Use the builder {@link #builder()} instead for better configuration + * options. + */ + @Deprecated + public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, Duration keepAliveInterval, + McpTransportContextExtractor contextExtractor) { + this.objectMapper = objectMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; + this.contextExtractor = contextExtractor; + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing.get()) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); } /** @@ -286,10 +366,13 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) body.append(line); } + final McpTransportContext transportContext = contextExtractor.extract(request, + new DefaultMcpTransportContext()); McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); // Process the message through the session's handle method - session.handle(message).block(); // Block for Servlet compatibility + // Block for Servlet compatibility + session.handle(message).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); response.setStatus(HttpServletResponse.SC_OK); } @@ -324,7 +407,13 @@ public Mono closeGracefully() { isClosing.set(true); logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); - return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then(); + return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then().doOnSuccess(v -> { + sessions.clear(); + logger.debug("Graceful shutdown completed"); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); } /** @@ -475,6 +564,10 @@ public static class Builder { private String sseEndpoint = DEFAULT_SSE_ENDPOINT; + private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + + private Duration keepAliveInterval; + /** * Sets the JSON object mapper to use for message serialization/deserialization. * @param objectMapper The object mapper to use @@ -522,6 +615,31 @@ public Builder sseEndpoint(String sseEndpoint) { return this; } + /** + * Sets the context extractor for extracting transport context from the request. + * @param contextExtractor The context extractor to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public HttpServletSseServerTransportProvider.Builder contextExtractor( + McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Sets the interval for keep-alive pings. + *

+ * If not specified, keep-alive pings will be disabled. + * @param keepAliveInterval The interval duration for keep-alive pings + * @return This builder instance for method chaining + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + /** * Builds a new instance of HttpServletSseServerTransportProvider with the * configured settings. @@ -535,7 +653,8 @@ public HttpServletSseServerTransportProvider build() { if (messageEndpoint == null) { throw new IllegalStateException("MessageEndpoint must be set"); } - return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); + return new HttpServletSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint, + keepAliveInterval, contextExtractor); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java new file mode 100644 index 000000000..25b003564 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java @@ -0,0 +1,306 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.server.DefaultMcpTransportContext; +import io.modelcontextprotocol.server.McpStatelessServerHandler; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.util.Assert; +import jakarta.servlet.ServletException; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import reactor.core.publisher.Mono; + +/** + * Implementation of an HttpServlet based {@link McpStatelessServerTransport}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +@WebServlet(asyncSupported = true) +public class HttpServletStatelessServerTransport extends HttpServlet implements McpStatelessServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(HttpServletStatelessServerTransport.class); + + public static final String UTF_8 = "UTF-8"; + + public static final String APPLICATION_JSON = "application/json"; + + public static final String TEXT_EVENT_STREAM = "text/event-stream"; + + public static final String ACCEPT = "Accept"; + + public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; + + private final ObjectMapper objectMapper; + + private final String mcpEndpoint; + + private McpStatelessServerHandler mcpHandler; + + private McpTransportContextExtractor contextExtractor; + + private volatile boolean isClosing = false; + + private HttpServletStatelessServerTransport(ObjectMapper objectMapper, String mcpEndpoint, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(objectMapper, "objectMapper must not be null"); + Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + + this.objectMapper = objectMapper; + this.mcpEndpoint = mcpEndpoint; + this.contextExtractor = contextExtractor; + } + + @Override + public void setMcpHandler(McpStatelessServerHandler mcpHandler) { + this.mcpHandler = mcpHandler; + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> this.isClosing = true); + } + + /** + * Handles GET requests - returns 405 METHOD NOT ALLOWED as stateless transport + * doesn't support GET requests. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED); + } + + /** + * Handles POST requests for incoming JSON-RPC messages from clients. + * @param request The HTTP servlet request containing the JSON-RPC message + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (isClosing) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + String accept = request.getHeader(ACCEPT); + if (accept == null || !(accept.contains(APPLICATION_JSON) && accept.contains(TEXT_EVENT_STREAM))) { + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + new McpError("Both application/json and text/event-stream required in Accept header")); + return; + } + + try { + BufferedReader reader = request.getReader(); + StringBuilder body = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + body.append(line); + } + + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + try { + McpSchema.JSONRPCResponse jsonrpcResponse = this.mcpHandler + .handleRequest(transportContext, jsonrpcRequest) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_OK); + + String jsonResponseText = objectMapper.writeValueAsString(jsonrpcResponse); + PrintWriter writer = response.getWriter(); + writer.write(jsonResponseText); + writer.flush(); + } + catch (Exception e) { + logger.error("Failed to handle request: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Failed to handle request: " + e.getMessage())); + } + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + try { + this.mcpHandler.handleNotification(transportContext, jsonrpcNotification) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + response.setStatus(HttpServletResponse.SC_ACCEPTED); + } + catch (Exception e) { + logger.error("Failed to handle notification: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Failed to handle notification: " + e.getMessage())); + } + } + else { + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + new McpError("The server accepts either requests or notifications")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError("Invalid message format")); + } + catch (Exception e) { + logger.error("Unexpected error handling message: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Unexpected error: " + e.getMessage())); + } + } + + /** + * Sends an error response to the client. + * @param response The HTTP servlet response + * @param httpCode The HTTP status code + * @param mcpError The MCP error to send + * @throws IOException If an I/O error occurs + */ + private void responseError(HttpServletResponse response, int httpCode, McpError mcpError) throws IOException { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(httpCode); + String jsonError = objectMapper.writeValueAsString(mcpError); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + } + + /** + * Cleans up resources when the servlet is being destroyed. + *

+ * This method ensures a graceful shutdown before calling the parent's destroy method. + */ + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + /** + * Create a builder for the server. + * @return a fresh {@link Builder} instance. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link HttpServletStatelessServerTransport}. + *

+ * This builder provides a fluent API for configuring and creating instances of + * HttpServletStatelessServerTransport with custom settings. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private String mcpEndpoint = "/mcp"; + + private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + + private Builder() { + // used by a static method + } + + /** + * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param objectMapper The ObjectMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if objectMapper is null + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.mcpEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Builds a new instance of {@link HttpServletStatelessServerTransport} with the + * configured settings. + * @return A new HttpServletStatelessServerTransport instance + * @throws IllegalStateException if required parameters are not set + */ + public HttpServletStatelessServerTransport build() { + Assert.notNull(objectMapper, "ObjectMapper must be set"); + Assert.notNull(mcpEndpoint, "Message endpoint must be set"); + + return new HttpServletStatelessServerTransport(objectMapper, mcpEndpoint, contextExtractor); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java new file mode 100644 index 000000000..8b95ec607 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -0,0 +1,851 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.server.DefaultMcpTransportContext; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.McpStreamableServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.KeepAliveScheduler; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Server-side implementation of the Model Context Protocol (MCP) streamable transport + * layer using HTTP with Server-Sent Events (SSE) through HttpServlet. This implementation + * provides a bridge between synchronous HttpServlet operations and reactive programming + * patterns to maintain compatibility with the reactive transport interface. + * + *

+ * This is the HttpServlet equivalent of + * {@link io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider} + * for the core MCP module, providing streamable HTTP transport functionality without + * Spring dependencies. + * + * @author Zachary German + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @see McpStreamableServerTransportProvider + * @see HttpServlet + */ +@WebServlet(asyncSupported = true) +public class HttpServletStreamableServerTransportProvider extends HttpServlet + implements McpStreamableServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(HttpServletStreamableServerTransportProvider.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Header name for the response media types accepted by the requester. + */ + private static final String ACCEPT = "Accept"; + + public static final String UTF_8 = "UTF-8"; + + public static final String APPLICATION_JSON = "application/json"; + + public static final String TEXT_EVENT_STREAM = "text/event-stream"; + + public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; + + /** + * The endpoint URI where clients should send their JSON-RPC messages. Defaults to + * "/mcp". + */ + private final String mcpEndpoint; + + /** + * Flag indicating whether DELETE requests are disallowed on the endpoint. + */ + private final boolean disallowDelete; + + private final ObjectMapper objectMapper; + + private McpStreamableServerSession.Factory sessionFactory; + + /** + * Map of active client sessions, keyed by mcp-session-id. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + private McpTransportContextExtractor contextExtractor; + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + /** + * Keep-alive scheduler for managing session pings. Activated if keepAliveInterval is + * set. Disabled by default. + */ + private KeepAliveScheduler keepAliveScheduler; + + /** + * Constructs a new HttpServletStreamableServerTransportProvider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP. This endpoint will handle GET, POST, and DELETE requests. + * @param disallowDelete Whether to disallow DELETE requests on the endpoint. + * @param contextExtractor The extractor for transport context from the request. + * @throws IllegalArgumentException if any parameter is null + */ + private HttpServletStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, + boolean disallowDelete, McpTransportContextExtractor contextExtractor, + Duration keepAliveInterval) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); + Assert.notNull(contextExtractor, "Context extractor must not be null"); + + this.objectMapper = objectMapper; + this.mcpEndpoint = mcpEndpoint; + this.disallowDelete = disallowDelete; + this.contextExtractor = contextExtractor; + + if (keepAliveInterval != null) { + + this.keepAliveScheduler = KeepAliveScheduler + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .initialDelay(keepAliveInterval) + .interval(keepAliveInterval) + .build(); + + this.keepAliveScheduler.start(); + } + + } + + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05, ProtocolVersions.MCP_2025_03_26); + } + + @Override + public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a notification to all connected clients through their SSE connections. + * If any errors occur during sending to a particular client, they are logged but + * don't prevent sending to other clients. + * @param method The method name for the notification + * @param params The parameters for the notification + * @return A Mono that completes when the broadcast attempt is finished + */ + @Override + public Mono notifyClients(String method, Object params) { + if (this.sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size()); + + return Mono.fromRunnable(() -> { + this.sessions.values().parallelStream().forEach(session -> { + try { + session.sendNotification(method, params).block(); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()); + } + }); + }); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when all cleanup operations are finished + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + this.isClosing = true; + logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size()); + + this.sessions.values().parallelStream().forEach(session -> { + try { + session.closeGracefully().block(); + } + catch (Exception e) { + logger.error("Failed to close session {}: {}", session.getId(), e.getMessage()); + } + }); + + this.sessions.clear(); + logger.debug("Graceful shutdown completed"); + }).then().doOnSuccess(v -> { + sessions.clear(); + logger.debug("Graceful shutdown completed"); + if (this.keepAliveScheduler != null) { + this.keepAliveScheduler.shutdown(); + } + }); + } + + /** + * Handles GET requests to establish SSE connections and message replay. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (this.isClosing) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + List badRequestErrors = new ArrayList<>(); + + String accept = request.getHeader(ACCEPT); + if (accept == null || !accept.contains(TEXT_EVENT_STREAM)) { + badRequestErrors.add("text/event-stream required in Accept header"); + } + + String sessionId = request.getHeader(HttpHeaders.MCP_SESSION_ID); + + if (sessionId == null || sessionId.isBlank()) { + badRequestErrors.add("Session ID required in mcp-session-id header"); + } + + if (!badRequestErrors.isEmpty()) { + String combinedMessage = String.join("; ", badRequestErrors); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage)); + return; + } + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + logger.debug("Handling GET request for session: {}", sessionId); + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + try { + response.setContentType(TEXT_EVENT_STREAM); + response.setCharacterEncoding(UTF_8); + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); + response.setHeader("Access-Control-Allow-Origin", "*"); + + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + HttpServletStreamableMcpSessionTransport sessionTransport = new HttpServletStreamableMcpSessionTransport( + sessionId, asyncContext, response.getWriter()); + + // Check if this is a replay request + if (request.getHeader(HttpHeaders.LAST_EVENT_ID) != null) { + String lastId = request.getHeader(HttpHeaders.LAST_EVENT_ID); + + try { + session.replay(lastId) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .toIterable() + .forEach(message -> { + try { + sessionTransport.sendMessage(message) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + } + catch (Exception e) { + logger.error("Failed to replay message: {}", e.getMessage()); + asyncContext.complete(); + } + }); + } + catch (Exception e) { + logger.error("Failed to replay messages: {}", e.getMessage()); + asyncContext.complete(); + } + } + else { + // Establish new listening stream + McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session + .listeningStream(sessionTransport); + + asyncContext.addListener(new jakarta.servlet.AsyncListener() { + @Override + public void onComplete(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection completed for session: {}", sessionId); + listeningStream.close(); + } + + @Override + public void onTimeout(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection timed out for session: {}", sessionId); + listeningStream.close(); + } + + @Override + public void onError(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection error for session: {}", sessionId); + listeningStream.close(); + } + + @Override + public void onStartAsync(jakarta.servlet.AsyncEvent event) throws IOException { + // No action needed + } + }); + } + } + catch (Exception e) { + logger.error("Failed to handle GET request for session {}: {}", sessionId, e.getMessage()); + response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + } + } + + /** + * Handles POST requests for incoming JSON-RPC messages from clients. + * @param request The HTTP servlet request containing the JSON-RPC message + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (this.isClosing) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + List badRequestErrors = new ArrayList<>(); + + String accept = request.getHeader(ACCEPT); + if (accept == null || !accept.contains(TEXT_EVENT_STREAM)) { + badRequestErrors.add("text/event-stream required in Accept header"); + } + if (accept == null || !accept.contains(APPLICATION_JSON)) { + badRequestErrors.add("application/json required in Accept header"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + try { + BufferedReader reader = request.getReader(); + StringBuilder body = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + body.append(line); + } + + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + + // Handle initialization request + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest + && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { + if (!badRequestErrors.isEmpty()) { + String combinedMessage = String.join("; ", badRequestErrors); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage)); + return; + } + + McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(), + new TypeReference() { + }); + McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory + .startSession(initializeRequest); + this.sessions.put(init.session().getId(), init.session()); + + try { + McpSchema.InitializeResult initResult = init.initResult().block(); + + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setHeader(HttpHeaders.MCP_SESSION_ID, init.session().getId()); + response.setStatus(HttpServletResponse.SC_OK); + + String jsonResponse = objectMapper.writeValueAsString(new McpSchema.JSONRPCResponse( + McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initResult, null)); + + PrintWriter writer = response.getWriter(); + writer.write(jsonResponse); + writer.flush(); + return; + } + catch (Exception e) { + logger.error("Failed to initialize session: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Failed to initialize session: " + e.getMessage())); + return; + } + } + + String sessionId = request.getHeader(HttpHeaders.MCP_SESSION_ID); + + if (sessionId == null || sessionId.isBlank()) { + badRequestErrors.add("Session ID required in mcp-session-id header"); + } + + if (!badRequestErrors.isEmpty()) { + String combinedMessage = String.join("; ", badRequestErrors); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError(combinedMessage)); + return; + } + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + this.responseError(response, HttpServletResponse.SC_NOT_FOUND, + new McpError("Session not found: " + sessionId)); + return; + } + + if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) { + session.accept(jsonrpcResponse) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + response.setStatus(HttpServletResponse.SC_ACCEPTED); + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + session.accept(jsonrpcNotification) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + response.setStatus(HttpServletResponse.SC_ACCEPTED); + } + else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + // For streaming responses, we need to return SSE + response.setContentType(TEXT_EVENT_STREAM); + response.setCharacterEncoding(UTF_8); + response.setHeader("Cache-Control", "no-cache"); + response.setHeader("Connection", "keep-alive"); + response.setHeader("Access-Control-Allow-Origin", "*"); + + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); + + HttpServletStreamableMcpSessionTransport sessionTransport = new HttpServletStreamableMcpSessionTransport( + sessionId, asyncContext, response.getWriter()); + + try { + session.responseStream(jsonrpcRequest, sessionTransport) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + } + catch (Exception e) { + logger.error("Failed to handle request stream: {}", e.getMessage()); + asyncContext.complete(); + } + } + else { + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Unknown message type")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + new McpError("Invalid message format: " + e.getMessage())); + } + catch (Exception e) { + logger.error("Error handling message: {}", e.getMessage()); + try { + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Error processing message: " + e.getMessage())); + } + catch (IOException ex) { + logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); + response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error processing message"); + } + } + } + + /** + * Handles DELETE requests for session deletion. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doDelete(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (this.isClosing) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + if (this.disallowDelete) { + response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED); + return; + } + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + if (request.getHeader(HttpHeaders.MCP_SESSION_ID) == null) { + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + new McpError("Session ID required in mcp-session-id header")); + return; + } + + String sessionId = request.getHeader(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + try { + session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); + this.sessions.remove(sessionId); + response.setStatus(HttpServletResponse.SC_OK); + } + catch (Exception e) { + logger.error("Failed to delete session {}: {}", sessionId, e.getMessage()); + try { + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError(e.getMessage())); + } + catch (IOException ex) { + logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage()); + response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error deleting session"); + } + } + } + + public void responseError(HttpServletResponse response, int httpCode, McpError mcpError) throws IOException { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(httpCode); + String jsonError = objectMapper.writeValueAsString(mcpError); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + return; + } + + /** + * Sends an SSE event to a client with a specific ID. + * @param writer The writer to send the event through + * @param eventType The type of event (message or endpoint) + * @param data The event data + * @param id The event ID + * @throws IOException If an error occurs while writing the event + */ + private void sendEvent(PrintWriter writer, String eventType, String data, String id) throws IOException { + if (id != null) { + writer.write("id: " + id + "\n"); + } + writer.write("event: " + eventType + "\n"); + writer.write("data: " + data + "\n\n"); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Client disconnected"); + } + } + + /** + * Cleans up resources when the servlet is being destroyed. + *

+ * This method ensures a graceful shutdown by closing all client connections before + * calling the parent's destroy method. + */ + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + /** + * Implementation of McpStreamableServerTransport for HttpServlet SSE sessions. This + * class handles the transport-level communication for a specific client session. + * + *

+ * This class is thread-safe and uses a ReentrantLock to synchronize access to the + * underlying PrintWriter to prevent race conditions when multiple threads attempt to + * send messages concurrently. + */ + + private class HttpServletStreamableMcpSessionTransport implements McpStreamableServerTransport { + + private final String sessionId; + + private final AsyncContext asyncContext; + + private final PrintWriter writer; + + private volatile boolean closed = false; + + private final ReentrantLock lock = new ReentrantLock(); + + /** + * Creates a new session transport with the specified ID and SSE writer. + * @param sessionId The unique identifier for this session + * @param asyncContext The async context for the session + * @param writer The writer for sending server events to the client + */ + HttpServletStreamableMcpSessionTransport(String sessionId, AsyncContext asyncContext, PrintWriter writer) { + this.sessionId = sessionId; + this.asyncContext = asyncContext; + this.writer = writer; + logger.debug("Streamable session transport {} initialized with SSE writer", sessionId); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection. + * @param message The JSON-RPC message to send + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return sendMessage(message, null); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection with a + * specific message ID. + * @param message The JSON-RPC message to send + * @param messageId The message ID for SSE event identification + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) { + return Mono.fromRunnable(() -> { + if (this.closed) { + logger.debug("Attempted to send message to closed session: {}", this.sessionId); + return; + } + + lock.lock(); + try { + if (this.closed) { + logger.debug("Session {} was closed during message send attempt", this.sessionId); + return; + } + + String jsonText = objectMapper.writeValueAsString(message); + HttpServletStreamableServerTransportProvider.this.sendEvent(writer, MESSAGE_EVENT_TYPE, jsonText, + messageId != null ? messageId : this.sessionId); + logger.debug("Message sent to session {} with ID {}", this.sessionId, messageId); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage()); + HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId); + this.asyncContext.complete(); + } + finally { + lock.unlock(); + } + }); + } + + /** + * Converts data from one type to another using the configured ObjectMapper. + * @param data The source data object to convert + * @param typeRef The target type reference + * @return The converted object of type T + * @param The target type + */ + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when the shutdown is complete + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + HttpServletStreamableMcpSessionTransport.this.close(); + }); + } + + /** + * Closes the transport immediately. + */ + @Override + public void close() { + lock.lock(); + try { + if (this.closed) { + logger.debug("Session transport {} already closed", this.sessionId); + return; + } + + this.closed = true; + + // HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId); + this.asyncContext.complete(); + logger.debug("Successfully completed async context for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete async context for session {}: {}", sessionId, e.getMessage()); + } + finally { + lock.unlock(); + } + } + + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of + * {@link HttpServletStreamableServerTransportProvider}. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private String mcpEndpoint = "/mcp"; + + private boolean disallowDelete = false; + + private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + + private Duration keepAliveInterval; + + /** + * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param objectMapper The ObjectMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if objectMapper is null + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param mcpEndpoint The MCP endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if mcpEndpoint is null + */ + public Builder mcpEndpoint(String mcpEndpoint) { + Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); + this.mcpEndpoint = mcpEndpoint; + return this; + } + + /** + * Sets whether to disallow DELETE requests on the endpoint. + * @param disallowDelete true to disallow DELETE requests, false otherwise + * @return this builder instance + */ + public Builder disallowDelete(boolean disallowDelete) { + this.disallowDelete = disallowDelete; + return this; + } + + /** + * Sets the context extractor for extracting transport context from the request. + * @param contextExtractor The context extractor to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Sets the keep-alive interval for the transport. If set, a keep-alive scheduler + * will be activated to periodically ping active sessions. + * @param keepAliveInterval The interval for keep-alive pings. If null, no + * keep-alive will be scheduled. + * @return this builder instance + */ + public Builder keepAliveInterval(Duration keepAliveInterval) { + this.keepAliveInterval = keepAliveInterval; + return this; + } + + /** + * Builds a new instance of {@link HttpServletStreamableServerTransportProvider} + * with the configured settings. + * @return A new HttpServletStreamableServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public HttpServletStreamableServerTransportProvider build() { + Assert.notNull(this.objectMapper, "ObjectMapper must be set"); + Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); + + return new HttpServletStreamableServerTransportProvider(this.objectMapper, this.mcpEndpoint, + this.disallowDelete, this.contextExtractor, this.keepAliveInterval); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index 9ef9c7829..af602f610 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -10,6 +10,7 @@ import java.io.InputStreamReader; import java.io.OutputStream; import java.nio.charset.StandardCharsets; +import java.util.List; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; @@ -22,6 +23,7 @@ import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -88,6 +90,11 @@ public StdioServerTransportProvider(ObjectMapper objectMapper, InputStream input this.outputStream = outputStream; } + @Override + public List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); + } + @Override public void setSessionFactory(McpServerSession.Factory sessionFactory) { // Create a single session for the stdio connection diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidator.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidator.java index cd8fc9659..f4bdc02eb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidator.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidator.java @@ -1,6 +1,7 @@ /* * Copyright 2024-2024 the original author or authors. */ + package io.modelcontextprotocol.spec; import java.util.Map; diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java new file mode 100644 index 000000000..f497afd43 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java @@ -0,0 +1,55 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; + +import java.time.Duration; +import java.util.Map; +import java.util.UUID; + +/** + * A default implementation of {@link McpStreamableServerSession.Factory}. + * + * @author Dariusz Jędrzejczyk + */ +public class DefaultMcpStreamableServerSessionFactory implements McpStreamableServerSession.Factory { + + Duration requestTimeout; + + McpStreamableServerSession.InitRequestHandler initRequestHandler; + + Map> requestHandlers; + + Map notificationHandlers; + + /** + * Constructs an instance + * @param requestTimeout timeout for requests + * @param initRequestHandler initialization request handler + * @param requestHandlers map of MCP request handlers keyed by method name + * @param notificationHandlers map of MCP notification handlers keyed by method name + */ + public DefaultMcpStreamableServerSessionFactory(Duration requestTimeout, + McpStreamableServerSession.InitRequestHandler initRequestHandler, + Map> requestHandlers, + Map notificationHandlers) { + this.requestTimeout = requestTimeout; + this.initRequestHandler = initRequestHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + @Override + public McpStreamableServerSession.McpStreamableServerSessionInit startSession( + McpSchema.InitializeRequest initializeRequest) { + return new McpStreamableServerSession.McpStreamableServerSessionInit( + new McpStreamableServerSession(UUID.randomUUID().toString(), initializeRequest.capabilities(), + initializeRequest.clientInfo(), requestTimeout, requestHandlers, notificationHandlers), + this.initRequestHandler.handle(initializeRequest)); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java index 56cdeaf7f..fdb7bfd89 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.spec; import org.reactivestreams.Publisher; diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java index eb2b7edeb..8d63fb50d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportStream.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.spec; import org.reactivestreams.Publisher; diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java b/mcp/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java new file mode 100644 index 000000000..65b80957c --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/HttpHeaders.java @@ -0,0 +1,29 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +/** + * Names of HTTP headers in use by MCP HTTP transports. + * + * @author Dariusz Jędrzejczyk + */ +public interface HttpHeaders { + + /** + * Identifies individual MCP sessions. + */ + String MCP_SESSION_ID = "mcp-session-id"; + + /** + * Identifies events within an SSE Stream. + */ + String LAST_EVENT_ID = "Last-Event-ID"; + + /** + * Identifies the MCP protocol version. + */ + String PROTOCOL_VERSION = "MCP-Protocol-Version"; + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java b/mcp/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java index c95e627a9..572d7c043 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/JsonSchemaValidator.java @@ -1,6 +1,7 @@ /* * Copyright 2024-2024 the original author or authors. */ + package io.modelcontextprotocol.spec; import java.util.Map; diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index cc7d2abf8..f7db3d7aa 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -221,7 +221,7 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti return Mono.defer(() -> { var handler = notificationHandlers.get(notification.method()); if (handler == null) { - logger.error("No handler registered for notification method: {}", notification.method()); + logger.warn("No handler registered for notification method: {}", notification); return Mono.empty(); } return handler.handle(notification.params()); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java index 5c3b33131..22aec831b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientTransport.java @@ -1,6 +1,7 @@ /* * Copyright 2024 - 2024 the original author or authors. */ + package io.modelcontextprotocol.spec; import java.util.function.Consumer; diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java index 13e43240b..6172d8637 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpError.java @@ -1,9 +1,11 @@ /* * Copyright 2024 - 2024 the original author or authors. */ + package io.modelcontextprotocol.spec; import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse.JSONRPCError; +import io.modelcontextprotocol.util.Assert; public class McpError extends RuntimeException { @@ -14,6 +16,7 @@ public McpError(JSONRPCError jsonRpcError) { this.jsonRpcError = jsonRpcError; } + @Deprecated public McpError(Object error) { super(error.toString()); } @@ -22,4 +25,37 @@ public JSONRPCError getJsonRpcError() { return jsonRpcError; } + public static Builder builder(int errorCode) { + return new Builder(errorCode); + } + + public static class Builder { + + private final int code; + + private String message; + + private Object data; + + private Builder(int code) { + this.code = code; + } + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder data(Object data) { + this.data = data; + return this; + } + + public McpError build() { + Assert.hasText(message, "message must not be empty"); + return new McpError(new JSONRPCError(code, message, data)); + } + + } + } \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java new file mode 100644 index 000000000..f43a2c1d9 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java @@ -0,0 +1,29 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +/** + * An {@link McpSession} which is capable of processing logging notifications and keeping + * track of a min logging level. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpLoggableSession extends McpSession { + + /** + * Set the minimum logging level for the client. Messages below this level will be + * filtered out. + * @param minLoggingLevel The minimum logging level + */ + void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel); + + /** + * Allows checking whether a particular logging level is allowed. + * @param loggingLevel the level to check + * @return whether the logging at the specified level is permitted. + */ + boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 5ba28cabd..bd8a01555 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -44,7 +44,8 @@ public final class McpSchema { private McpSchema() { } - public static final String LATEST_PROTOCOL_VERSION = "2024-11-05"; + @Deprecated + public static final String LATEST_PROTOCOL_VERSION = ProtocolVersions.MCP_2025_03_26; public static final String JSONRPC_VERSION = "2.0"; @@ -236,6 +237,17 @@ public record JSONRPCRequest( // @formatter:off @JsonProperty("method") String method, @JsonProperty("id") Object id, @JsonProperty("params") Object params) implements JSONRPCMessage { // @formatter:on + + /** + * Constructor that validates MCP-specific ID requirements. Unlike base JSON-RPC, + * MCP requires that: (1) Requests MUST include a string or integer ID; (2) The ID + * MUST NOT be null + */ + public JSONRPCRequest { + Assert.notNull(id, "MCP requests MUST include an ID - null IDs are not allowed"); + Assert.isTrue(id instanceof String || id instanceof Integer || id instanceof Long, + "MCP requests MUST have an ID that is either a string or integer"); + } } /** @@ -511,6 +523,21 @@ public record ResourceCapabilities(@JsonProperty("subscribe") Boolean subscribe, public record ToolCapabilities(@JsonProperty("listChanged") Boolean listChanged) { } + /** + * Create a mutated copy of this object with the specified changes. + * @return A new Builder instance with the same values as this object. + */ + public Builder mutate() { + var builder = new Builder(); + builder.completions = this.completions; + builder.experimental = this.experimental; + builder.logging = this.logging; + builder.prompts = this.prompts; + builder.resources = this.resources; + builder.tools = this.tools; + return builder; + } + public static Builder builder() { return new Builder(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 86906d859..669c10b83 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.spec; import java.time.Duration; @@ -9,6 +13,11 @@ import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpInitRequestHandler; +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; @@ -16,10 +25,10 @@ import reactor.core.publisher.Sinks; /** - * Represents a Model Control Protocol (MCP) session on the server side. It manages + * Represents a Model Context Protocol (MCP) session on the server side. It manages * bidirectional JSON-RPC communication with the client. */ -public class McpServerSession implements McpSession { +public class McpServerSession implements McpLoggableSession { private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class); @@ -32,13 +41,11 @@ public class McpServerSession implements McpSession { private final AtomicLong requestCounter = new AtomicLong(0); - private final InitRequestHandler initRequestHandler; - - private final InitNotificationHandler initNotificationHandler; + private final McpInitRequestHandler initRequestHandler; - private final Map> requestHandlers; + private final Map> requestHandlers; - private final Map notificationHandlers; + private final Map notificationHandlers; private final McpServerTransport transport; @@ -56,6 +63,29 @@ public class McpServerSession implements McpSession { private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; + + /** + * Creates a new server session with the given parameters and the transport to use. + * @param id session id + * @param transport the transport to use + * @param initHandler called when a + * {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the + * server + * @param requestHandlers map of request handlers to use + * @param notificationHandlers map of notification handlers to use + */ + public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, + McpInitRequestHandler initHandler, Map> requestHandlers, + Map notificationHandlers) { + this.id = id; + this.requestTimeout = requestTimeout; + this.transport = transport; + this.initRequestHandler = initHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + /** * Creates a new server session with the given parameters and the transport to use. * @param id session id @@ -68,15 +98,18 @@ public class McpServerSession implements McpSession { * received. * @param requestHandlers map of request handlers to use * @param notificationHandlers map of notification handlers to use + * @deprecated Use + * {@link #McpServerSession(String, Duration, McpServerTransport, McpInitRequestHandler, Map, Map)} */ + @Deprecated public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, - InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, - Map> requestHandlers, Map notificationHandlers) { + McpInitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, + Map> requestHandlers, + Map notificationHandlers) { this.id = id; this.requestTimeout = requestTimeout; this.transport = transport; this.initRequestHandler = initHandler; - this.initNotificationHandler = initNotificationHandler; this.requestHandlers = requestHandlers; this.notificationHandlers = notificationHandlers; } @@ -108,6 +141,17 @@ private String generateRequestId() { return this.id + "-" + this.requestCounter.getAndIncrement(); } + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + @Override public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { String requestId = this.generateRequestId(); @@ -154,7 +198,9 @@ public Mono sendNotification(String method, Object params) { * @return a Mono that completes when the message is processed */ public Mono handle(McpSchema.JSONRPCMessage message) { - return Mono.defer(() -> { + return Mono.deferContextual(ctx -> { + McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // TODO handle errors for communication to without initialization happening // first if (message instanceof McpSchema.JSONRPCResponse response) { @@ -170,7 +216,7 @@ public Mono handle(McpSchema.JSONRPCMessage message) { } else if (message instanceof McpSchema.JSONRPCRequest request) { logger.debug("Received request: {}", request); - return handleIncomingRequest(request).onErrorResume(error -> { + return handleIncomingRequest(request, transportContext).onErrorResume(error -> { var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, error.getMessage(), null)); @@ -183,7 +229,7 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) { // happening first logger.debug("Received notification: {}", notification); // TODO: in case of error, should the POST request be signalled? - return handleIncomingNotification(notification) + return handleIncomingNotification(notification, transportContext) .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); } else { @@ -196,9 +242,11 @@ else if (message instanceof McpSchema.JSONRPCNotification notification) { /** * Handles an incoming JSON-RPC request by routing it to the appropriate handler. * @param request The incoming JSON-RPC request + * @param transportContext * @return A Mono containing the JSON-RPC response */ - private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { + private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request, + McpTransportContext transportContext) { return Mono.defer(() -> { Mono resultMono; if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { @@ -222,7 +270,11 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR error.message(), error.data()))); } - resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); + resultMono = this.exchangeSink.asMono().flatMap(exchange -> { + McpAsyncServerExchange newExchange = new McpAsyncServerExchange(exchange.sessionId(), this, + exchange.getClientCapabilities(), exchange.getClientInfo(), transportContext); + return handler.handle(newExchange, request.params()); + }); } return resultMono .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) @@ -236,22 +288,30 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR /** * Handles an incoming JSON-RPC notification by routing it to the appropriate handler. * @param notification The incoming JSON-RPC notification + * @param transportContext * @return A Mono that completes when the notification is processed */ - private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { + private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification, + McpTransportContext transportContext) { return Mono.defer(() -> { if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { this.state.lazySet(STATE_INITIALIZED); - exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get())); - return this.initNotificationHandler.handle(); + // FIXME: The session ID passed here is not the same as the one in the + // legacy SSE transport. + exchangeSink.tryEmitValue(new McpAsyncServerExchange(this.id, this, clientCapabilities.get(), + clientInfo.get(), transportContext)); } var handler = notificationHandlers.get(notification.method()); if (handler == null) { - logger.error("No handler registered for notification method: {}", notification.method()); + logger.warn("No handler registered for notification method: {}", notification); return Mono.empty(); } - return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params())); + return this.exchangeSink.asMono().flatMap(exchange -> { + McpAsyncServerExchange newExchange = new McpAsyncServerExchange(exchange.sessionId(), this, + exchange.getClientCapabilities(), exchange.getClientInfo(), transportContext); + return handler.handle(newExchange, notification.params()); + }); }); } @@ -264,17 +324,22 @@ private MethodNotFoundError getMethodNotFoundError(String method) { @Override public Mono closeGracefully() { + // TODO: clear pendingResponses and emit errors? return this.transport.closeGracefully(); } @Override public void close() { + // TODO: clear pendingResponses and emit errors? this.transport.close(); } /** * Request handler for the initialization request. + * + * @deprecated Use {@link McpInitRequestHandler} */ + @Deprecated public interface InitRequestHandler { /** @@ -301,7 +366,10 @@ public interface InitNotificationHandler { /** * A handler for client-initiated notifications. + * + * @deprecated Use {@link McpNotificationHandler} */ + @Deprecated public interface NotificationHandler { /** @@ -320,7 +388,9 @@ public interface NotificationHandler { * * @param the type of the response that is expected as a result of handling the * request. + * @deprecated Use {@link McpRequestHandler} */ + @Deprecated public interface RequestHandler { /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java index 632b8cee6..39c1644e0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.spec; /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java index 5fdbd7ab6..02028ccdf 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java @@ -1,35 +1,16 @@ -package io.modelcontextprotocol.spec; - -import java.util.Map; +/* + * Copyright 2024-2025 the original author or authors. + */ -import reactor.core.publisher.Mono; +package io.modelcontextprotocol.spec; /** - * The core building block providing the server-side MCP transport. Implement this - * interface to bridge between a particular server-side technology and the MCP server - * transport layer. - * - *

- * The lifecycle of the provider dictates that it be created first, upon application - * startup, and then passed into either - * {@link io.modelcontextprotocol.server.McpServer#sync(McpServerTransportProvider)} or - * {@link io.modelcontextprotocol.server.McpServer#async(McpServerTransportProvider)}. As - * a result of the MCP server creation, the provider will be notified of a - * {@link McpServerSession.Factory} which will be used to handle a 1:1 communication - * between a newly connected client and the server. The provider's responsibility is to - * create instances of {@link McpServerTransport} that the session will utilise during the - * session lifetime. - * - *

- * Finally, the {@link McpServerTransport}s can be closed in bulk when {@link #close()} or - * {@link #closeGracefully()} are called as part of the normal application shutdown event. - * Individual {@link McpServerTransport}s can also be closed on a per-session basis, where - * the {@link McpServerSession#close()} or {@link McpServerSession#closeGracefully()} - * closes the provided transport. + * Classic implementation of {@link McpServerTransportProviderBase} for a single outgoing + * stream in bidirectional communication (STDIO and the legacy HTTP SSE). * * @author Dariusz Jędrzejczyk */ -public interface McpServerTransportProvider { +public interface McpServerTransportProvider extends McpServerTransportProviderBase { /** * Sets the session factory that will be used to create sessions for new clients. An @@ -39,28 +20,4 @@ public interface McpServerTransportProvider { */ void setSessionFactory(McpServerSession.Factory sessionFactory); - /** - * Sends a notification to all connected clients. - * @param method the name of the notification method to be called on the clients - * @param params parameters to be sent with the notification - * @return a Mono that completes when the notification has been broadcast - * @see McpSession#sendNotification(String, Map) - */ - Mono notifyClients(String method, Object params); - - /** - * Immediately closes all the transports with connected clients and releases any - * associated resources. - */ - default void close() { - this.closeGracefully().subscribe(); - } - - /** - * Gracefully closes all the transports with connected clients and releases any - * associated resources asynchronously. - * @return a {@link Mono} that completes when the connections have been closed. - */ - Mono closeGracefully(); - } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java new file mode 100644 index 000000000..acb1ecac6 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java @@ -0,0 +1,71 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.util.List; +import java.util.Map; + +import reactor.core.publisher.Mono; + +/** + * The core building block providing the server-side MCP transport. Implement this + * interface to bridge between a particular server-side technology and the MCP server + * transport layer. + * + *

+ * The lifecycle of the provider dictates that it be created first, upon application + * startup, and then passed into either + * {@link io.modelcontextprotocol.server.McpServer#sync(McpServerTransportProvider)} or + * {@link io.modelcontextprotocol.server.McpServer#async(McpServerTransportProvider)}. As + * a result of the MCP server creation, the provider will be notified of a + * {@link McpServerSession.Factory} which will be used to handle a 1:1 communication + * between a newly connected client and the server. The provider's responsibility is to + * create instances of {@link McpServerTransport} that the session will utilise during the + * session lifetime. + * + *

+ * Finally, the {@link McpServerTransport}s can be closed in bulk when {@link #close()} or + * {@link #closeGracefully()} are called as part of the normal application shutdown event. + * Individual {@link McpServerTransport}s can also be closed on a per-session basis, where + * the {@link McpServerSession#close()} or {@link McpServerSession#closeGracefully()} + * closes the provided transport. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpServerTransportProviderBase { + + /** + * Sends a notification to all connected clients. + * @param method the name of the notification method to be called on the clients + * @param params parameters to be sent with the notification + * @return a Mono that completes when the notification has been broadcast + * @see McpSession#sendNotification(String, Map) + */ + Mono notifyClients(String method, Object params); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + + /** + * Returns the protocol version supported by this transport provider. + * @return the protocol version as a string + */ + default List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java index 42d170db5..3473a4da8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java @@ -8,7 +8,7 @@ import reactor.core.publisher.Mono; /** - * Represents a Model Control Protocol (MCP) session that handles communication between + * Represents a Model Context Protocol (MCP) session that handles communication between * clients and the server. This interface provides methods for sending requests and * notifications, as well as managing the session lifecycle. * diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java new file mode 100644 index 000000000..c1234b130 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.util.List; + +import io.modelcontextprotocol.server.McpStatelessServerHandler; +import reactor.core.publisher.Mono; + +public interface McpStatelessServerTransport { + + void setMcpHandler(McpStatelessServerHandler mcpHandler); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + + default List protocolVersions() { + return List.of(ProtocolVersions.MCP_2025_03_26); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java new file mode 100644 index 000000000..ef7967c1e --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -0,0 +1,403 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; + +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; + +/** + * Representation of a Streamable HTTP server session that keeps track of mapping + * server-initiated requests to the client and mapping arriving responses. It also allows + * handling incoming notifications. For requests, it provides the default SSE streaming + * capability without the insight into the transport-specific details of HTTP handling. + * + * @author Dariusz Jędrzejczyk + */ +public class McpStreamableServerSession implements McpLoggableSession { + + private static final Logger logger = LoggerFactory.getLogger(McpStreamableServerSession.class); + + private final ConcurrentHashMap requestIdToStream = new ConcurrentHashMap<>(); + + private final String id; + + private final Duration requestTimeout; + + private final AtomicLong requestCounter = new AtomicLong(0); + + private final Map> requestHandlers; + + private final Map notificationHandlers; + + private final AtomicReference clientCapabilities = new AtomicReference<>(); + + private final AtomicReference clientInfo = new AtomicReference<>(); + + private final AtomicReference listeningStreamRef; + + private final MissingMcpTransportSession missingMcpTransportSession; + + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; + + /** + * Create an instance of the streamable session. + * @param id session ID + * @param clientCapabilities client capabilities + * @param clientInfo client info + * @param requestTimeout timeout to use for requests + * @param requestHandlers the map of MCP request handlers keyed by method name + * @param notificationHandlers the map of MCP notification handlers keyed by method + * name + */ + public McpStreamableServerSession(String id, McpSchema.ClientCapabilities clientCapabilities, + McpSchema.Implementation clientInfo, Duration requestTimeout, + Map> requestHandlers, + Map notificationHandlers) { + this.id = id; + this.missingMcpTransportSession = new MissingMcpTransportSession(id); + this.listeningStreamRef = new AtomicReference<>(this.missingMcpTransportSession); + this.clientCapabilities.lazySet(clientCapabilities); + this.clientInfo.lazySet(clientInfo); + this.requestTimeout = requestTimeout; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + + /** + * Return the Session ID. + * @return session ID + */ + public String getId() { + return this.id; + } + + private String generateRequestId() { + return this.id + "-" + this.requestCounter.getAndIncrement(); + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + return Mono.defer(() -> { + McpLoggableSession listeningStream = this.listeningStreamRef.get(); + return listeningStream.sendRequest(method, requestParams, typeRef); + }); + } + + @Override + public Mono sendNotification(String method, Object params) { + return Mono.defer(() -> { + McpLoggableSession listeningStream = this.listeningStreamRef.get(); + return listeningStream.sendNotification(method, params); + }); + } + + public Mono delete() { + return this.closeGracefully().then(Mono.fromRunnable(() -> { + // TODO: review in the context of history storage + // delete history, etc. + })); + } + + /** + * Create a listening stream (the generic HTTP GET request without Last-Event-ID + * header). + * @param transport The dedicated SSE transport stream + * @return a stream representation + */ + public McpStreamableServerSessionStream listeningStream(McpStreamableServerTransport transport) { + McpStreamableServerSessionStream listeningStream = new McpStreamableServerSessionStream(transport); + this.listeningStreamRef.set(listeningStream); + return listeningStream; + } + + // TODO: keep track of history by keeping a map from eventId to stream and then + // iterate over the events using the lastEventId + public Flux replay(Object lastEventId) { + return Flux.empty(); + } + + /** + * Provide the SSE stream of MCP messages finalized with a Response. + * @param jsonrpcRequest the MCP request triggering the stream creation + * @param transport the SSE transport stream to send messages to + * @return Mono which completes once the processing is done + */ + public Mono responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpStreamableServerTransport transport) { + return Mono.deferContextual(ctx -> { + McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + + McpStreamableServerSessionStream stream = new McpStreamableServerSessionStream(transport); + McpRequestHandler requestHandler = McpStreamableServerSession.this.requestHandlers + .get(jsonrpcRequest.method()); + // TODO: delegate to stream, which upon successful response should close + // remove itself from the registry and also close the underlying transport + // (sink) + if (requestHandler == null) { + MethodNotFoundError error = getMethodNotFoundError(jsonrpcRequest.method()); + return transport + .sendMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + error.message(), error.data()))); + } + return requestHandler + .handle(new McpAsyncServerExchange(this.id, stream, clientCapabilities.get(), clientInfo.get(), + transportContext), jsonrpcRequest.params()) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), result, + null)) + .onErrorResume(e -> { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), + null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + e.getMessage(), null)); + return Mono.just(errorResponse); + }) + .flatMap(transport::sendMessage) + .then(transport.closeGracefully()); + }); + } + + /** + * Handle the MCP notification. + * @param notification MCP notification + * @return Mono which completes upon succesful handling + */ + public Mono accept(McpSchema.JSONRPCNotification notification) { + return Mono.deferContextual(ctx -> { + McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + McpNotificationHandler notificationHandler = this.notificationHandlers.get(notification.method()); + if (notificationHandler == null) { + logger.warn("No handler registered for notification method: {}", notification); + return Mono.empty(); + } + McpLoggableSession listeningStream = this.listeningStreamRef.get(); + return notificationHandler.handle(new McpAsyncServerExchange(this.id, listeningStream, + this.clientCapabilities.get(), this.clientInfo.get(), transportContext), notification.params()); + }); + + } + + /** + * Handle the MCP response. + * @param response MCP response to the server-initiated request + * @return Mono which completes upon successful processing + */ + public Mono accept(McpSchema.JSONRPCResponse response) { + return Mono.defer(() -> { + var stream = this.requestIdToStream.get(response.id()); + if (stream == null) { + return Mono.error(new McpError("Unexpected response for unknown id " + response.id())); // TODO + // JSONize + } + // TODO: encapsulate this inside the stream itself + var sink = stream.pendingResponses.remove(response.id()); + if (sink == null) { + return Mono.error(new McpError("Unexpected response for unknown id " + response.id())); // TODO + // JSONize + } + else { + sink.success(response); + } + return Mono.empty(); + }); + } + + record MethodNotFoundError(String method, String message, Object data) { + } + + private MethodNotFoundError getMethodNotFoundError(String method) { + return new MethodNotFoundError(method, "Method not found: " + method, null); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + McpLoggableSession listeningStream = this.listeningStreamRef.getAndSet(missingMcpTransportSession); + return listeningStream.closeGracefully(); + // TODO: Also close all the open streams + }); + } + + @Override + public void close() { + McpLoggableSession listeningStream = this.listeningStreamRef.getAndSet(missingMcpTransportSession); + if (listeningStream != null) { + listeningStream.close(); + } + // TODO: Also close all open streams + } + + /** + * Request handler for the initialization request. + */ + public interface InitRequestHandler { + + /** + * Handles the initialization request. + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ + Mono handle(McpSchema.InitializeRequest initializeRequest); + + } + + /** + * Factory for new Streamable HTTP MCP sessions. + */ + public interface Factory { + + /** + * Given an initialize request, create a composite for the session initialization + * @param initializeRequest the initialization request from the client + * @return a composite allowing the session to start + */ + McpStreamableServerSessionInit startSession(McpSchema.InitializeRequest initializeRequest); + + } + + /** + * Composite holding the {@link McpStreamableServerSession} and the initialization + * result + * + * @param session the session instance + * @param initResult the result to use to respond to the client + */ + public record McpStreamableServerSessionInit(McpStreamableServerSession session, + Mono initResult) { + } + + /** + * An individual SSE stream within a Streamable HTTP context. Can be either the + * listening GET SSE stream or a request-specific POST SSE stream. + */ + public final class McpStreamableServerSessionStream implements McpLoggableSession { + + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + + private final McpStreamableServerTransport transport; + + private final String transportId; + + private final Supplier uuidGenerator; + + /** + * Constructor accepting the dedicated transport representing the SSE stream. + * @param transport request-specific SSE transport stream + */ + public McpStreamableServerSessionStream(McpStreamableServerTransport transport) { + this.transport = transport; + this.transportId = UUID.randomUUID().toString(); + // This ID design allows for a constant-time extraction of the history by + // precisely identifying the SSE stream using the first component + this.uuidGenerator = () -> this.transportId + "_" + UUID.randomUUID(); + } + + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + McpStreamableServerSession.this.setMinLoggingLevel(minLoggingLevel); + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return McpStreamableServerSession.this.isNotificationForLevelAllowed(loggingLevel); + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + String requestId = McpStreamableServerSession.this.generateRequestId(); + + McpStreamableServerSession.this.requestIdToStream.put(requestId, this); + + return Mono.create(sink -> { + this.pendingResponses.put(requestId, sink); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + method, requestId, requestParams); + String messageId = this.uuidGenerator.get(); + // TODO: store message in history + this.transport.sendMessage(jsonrpcRequest, messageId).subscribe(v -> { + }, sink::error); + }).timeout(requestTimeout).doOnError(e -> { + this.pendingResponses.remove(requestId); + McpStreamableServerSession.this.requestIdToStream.remove(requestId); + }).handle((jsonRpcResponse, sink) -> { + if (jsonRpcResponse.error() != null) { + sink.error(new McpError(jsonRpcResponse.error())); + } + else { + if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } + else { + sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + } + } + }); + } + + @Override + public Mono sendNotification(String method, Object params) { + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification( + McpSchema.JSONRPC_VERSION, method, params); + String messageId = this.uuidGenerator.get(); + // TODO: store message in history + return this.transport.sendMessage(jsonrpcNotification, messageId); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + this.pendingResponses.values().forEach(s -> s.error(new RuntimeException("Stream closed"))); + this.pendingResponses.clear(); + // If this was the generic stream, reset it + McpStreamableServerSession.this.listeningStreamRef.compareAndExchange(this, + McpStreamableServerSession.this.missingMcpTransportSession); + McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals); + return this.transport.closeGracefully(); + }); + } + + @Override + public void close() { + this.pendingResponses.values().forEach(s -> s.error(new RuntimeException("Stream closed"))); + this.pendingResponses.clear(); + // If this was the generic stream, reset it + McpStreamableServerSession.this.listeningStreamRef.compareAndExchange(this, + McpStreamableServerSession.this.missingMcpTransportSession); + McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals); + this.transport.close(); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java new file mode 100644 index 000000000..f53c68900 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java @@ -0,0 +1,24 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import reactor.core.publisher.Mono; + +/** + * Streamable HTTP server transport representing an individual SSE stream. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpStreamableServerTransport extends McpServerTransport { + + /** + * Send a message to the client with a message ID for use in the SSE event payload + * @param message the JSON-RPC payload + * @param messageId message id for SSE events + * @return Mono which completes when done + */ + Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java new file mode 100644 index 000000000..09fe9fb0e --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java @@ -0,0 +1,69 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import reactor.core.publisher.Mono; + +/** + * The core building block providing the server-side MCP transport for Streamable HTTP + * servers. Implement this interface to bridge between a particular server-side technology + * and the MCP server transport layer. + * + *

+ * The lifecycle of the provider dictates that it be created first, upon application + * startup, and then passed into either + * {@link io.modelcontextprotocol.server.McpServer#sync(McpStreamableServerTransportProvider)} + * or + * {@link io.modelcontextprotocol.server.McpServer#async(McpStreamableServerTransportProvider)}. + * As a result of the MCP server creation, the provider will be notified of a + * {@link McpStreamableServerSession.Factory} which will be used to handle a 1:1 + * communication between a newly connected client and the server using a session concept. + * The provider's responsibility is to create instances of + * {@link McpStreamableServerTransport} that the session will utilise during the session + * lifetime. + * + *

+ * Finally, the {@link McpStreamableServerTransport}s can be closed in bulk when + * {@link #close()} or {@link #closeGracefully()} are called as part of the normal + * application shutdown event. Individual {@link McpStreamableServerTransport}s can also + * be closed on a per-session basis, where the {@link McpServerSession#close()} or + * {@link McpServerSession#closeGracefully()} closes the provided transport. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpStreamableServerTransportProvider extends McpServerTransportProviderBase { + + /** + * Sets the session factory that will be used to create sessions for new clients. An + * implementation of the MCP server MUST call this method before any MCP interactions + * take place. + * @param sessionFactory the session factory to be used for initiating client sessions + */ + void setSessionFactory(McpStreamableServerSession.Factory sessionFactory); + + /** + * Sends a notification to all connected clients. + * @param method the name of the notification method to be called on the clients + * @param params parameters to be sent with the notification + * @return a Mono that completes when the notification has been broadcast + */ + Mono notifyClients(String method, Object params); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java index 40d9ba7ac..1922548a6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransport.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.spec; +import java.util.List; + import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import reactor.core.publisher.Mono; @@ -77,4 +79,8 @@ default void close() { */ T unmarshalFrom(Object data, TypeReference typeRef); + default List protocolVersions() { + return List.of(ProtocolVersions.MCP_2024_11_05); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java index 555f018f8..716ff0d16 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSession.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.spec; import org.reactivestreams.Publisher; diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java index 474a18ae0..eced49ec3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportSessionNotFoundException.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.spec; /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java index 2d6dcce75..322afda63 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportStream.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.spec; import org.reactivestreams.Publisher; diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java new file mode 100644 index 000000000..aa33a8167 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java @@ -0,0 +1,63 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Mono; + +/** + * A {@link McpLoggableSession} which represents a missing stream that would allow the + * server to communicate with the client. Specifically, it can be used when a Streamable + * HTTP client has not opened a listening SSE stream to accept messages for interactions + * unrelated with concurrently running client-initiated requests. + * + * @author Dariusz Jędrzejczyk + */ +public class MissingMcpTransportSession implements McpLoggableSession { + + private final String sessionId; + + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; + + /** + * Create an instance with the Session ID specified. + * @param sessionId session ID + */ + public MissingMcpTransportSession(String sessionId) { + this.sessionId = sessionId; + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + return Mono.error(new IllegalStateException("Stream unavailable for session " + this.sessionId)); + } + + @Override + public Mono sendNotification(String method, Object params) { + return Mono.error(new IllegalStateException("Stream unavailable for session " + this.sessionId)); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public void close() { + } + + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java b/mcp/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java new file mode 100644 index 000000000..d8cb913a5 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/ProtocolVersions.java @@ -0,0 +1,23 @@ +package io.modelcontextprotocol.spec; + +public interface ProtocolVersions { + + /** + * MCP protocol version for 2024-11-05. + * https://modelcontextprotocol.io/specification/2024-11-05 + */ + String MCP_2024_11_05 = "2024-11-05"; + + /** + * MCP protocol version for 2025-03-26. + * https://modelcontextprotocol.io/specification/2025-03-26 + */ + String MCP_2025_03_26 = "2025-03-26"; + + /** + * MCP protocol version for 2025-06-18. + * https://modelcontextprotocol.io/specification/2025-06-18 + */ + String MCP_2025_06_18 = "2025-06-18"; + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Assert.java b/mcp/src/main/java/io/modelcontextprotocol/util/Assert.java index d68188c6f..1fa6b3058 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Assert.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Assert.java @@ -76,4 +76,17 @@ public static boolean hasText(@Nullable String str) { return (str != null && !str.isBlank()); } + /** + * Assert a boolean expression, throwing an {@code IllegalArgumentException} if the + * expression evaluates to {@code false}. + * @param expression a boolean expression + * @param message the exception message to use if the assertion fails + * @throws IllegalArgumentException if {@code expression} is {@code false} + */ + public static void isTrue(boolean expression, String message) { + if (!expression) { + throw new IllegalArgumentException(message); + } + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java b/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java index 3870b76fc..44ea31690 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/DeafaultMcpUriTemplateManagerFactory.java @@ -1,6 +1,7 @@ /* * Copyright 2025 - 2025 the original author or authors. */ + package io.modelcontextprotocol.util; /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java b/mcp/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java new file mode 100644 index 000000000..9d411cd41 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/util/KeepAliveScheduler.java @@ -0,0 +1,217 @@ +/** + * Copyright 2025 - 2025 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSession; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +/** + * A utility class for scheduling regular keep-alive calls to maintain connections. It + * sends periodic keep-alive, ping, messages to connected mcp clients to prevent idle + * timeouts. + * + * The pings are sent to all active mcp sessions at regular intervals. + * + * @author Christian Tzolov + */ +public class KeepAliveScheduler { + + private static final Logger logger = LoggerFactory.getLogger(KeepAliveScheduler.class); + + private static final TypeReference OBJECT_TYPE_REF = new TypeReference<>() { + }; + + /** Initial delay before the first keepAlive call */ + private final Duration initialDelay; + + /** Interval between subsequent keepAlive calls */ + private final Duration interval; + + /** The scheduler used for executing keepAlive calls */ + private final Scheduler scheduler; + + /** The current state of the scheduler */ + private final AtomicBoolean isRunning = new AtomicBoolean(false); + + /** The current subscription for the keepAlive calls */ + private Disposable currentSubscription; + + // TODO Currently we do not support the streams (streamable http session created by + // http post/get) + + /** Supplier for reactive McpSession instances */ + private final Supplier> mcpSessions; + + /** + * Creates a KeepAliveScheduler with a custom scheduler, initial delay, interval and a + * supplier for McpSession instances. + * @param scheduler The scheduler to use for executing keepAlive calls + * @param initialDelay Initial delay before the first keepAlive call + * @param interval Interval between subsequent keepAlive calls + * @param mcpSessions Supplier for McpSession instances + */ + KeepAliveScheduler(Scheduler scheduler, Duration initialDelay, Duration interval, + Supplier> mcpSessions) { + this.scheduler = scheduler; + this.initialDelay = initialDelay; + this.interval = interval; + this.mcpSessions = mcpSessions; + } + + /** + * Creates a new Builder instance for constructing KeepAliveScheduler. + * @return A new Builder instance + */ + public static Builder builder(Supplier> mcpSessions) { + return new Builder(mcpSessions); + } + + /** + * Starts regular keepAlive calls with sessions supplier. + * @return Disposable to control the scheduled execution + */ + public Disposable start() { + if (this.isRunning.compareAndSet(false, true)) { + + this.currentSubscription = Flux.interval(this.initialDelay, this.interval, this.scheduler) + .doOnNext(tick -> { + this.mcpSessions.get() + .flatMap(session -> session.sendRequest(McpSchema.METHOD_PING, null, OBJECT_TYPE_REF) + .doOnError(e -> logger.warn("Failed to send keep-alive ping to session {}: {}", session, + e.getMessage())) + .onErrorComplete()) + .subscribe(); + }) + .doOnCancel(() -> this.isRunning.set(false)) + .doOnComplete(() -> this.isRunning.set(false)) + .onErrorComplete(error -> { + logger.error("KeepAlive scheduler error", error); + this.isRunning.set(false); + return true; + }) + .subscribe(); + + return this.currentSubscription; + } + else { + throw new IllegalStateException("KeepAlive scheduler is already running. Stop it first."); + } + } + + /** + * Stops the currently running keepAlive scheduler. + */ + public void stop() { + if (this.currentSubscription != null && !this.currentSubscription.isDisposed()) { + this.currentSubscription.dispose(); + } + this.isRunning.set(false); + } + + /** + * Checks if the scheduler is currently running. + * @return true if running, false otherwise + */ + public boolean isRunning() { + return this.isRunning.get(); + } + + /** + * Shuts down the scheduler and releases resources. + */ + public void shutdown() { + stop(); + if (this.scheduler instanceof Disposable) { + ((Disposable) this.scheduler).dispose(); + } + } + + /** + * Builder class for creating KeepAliveScheduler instances with fluent API. + */ + public static class Builder { + + private Scheduler scheduler = Schedulers.boundedElastic(); + + private Duration initialDelay = Duration.ofSeconds(0); + + private Duration interval = Duration.ofSeconds(30); + + private Supplier> mcpSessions; + + /** + * Creates a new Builder instance with a supplier for McpSession instances. + * @param mcpSessions The supplier for McpSession instances + */ + Builder(Supplier> mcpSessions) { + Assert.notNull(mcpSessions, "McpSessions supplier must not be null"); + this.mcpSessions = mcpSessions; + } + + /** + * Sets the scheduler to use for executing keepAlive calls. + * @param scheduler The scheduler to use: + *
    + *
  • Schedulers.single() - single-threaded scheduler
  • + *
  • Schedulers.boundedElastic() - bounded elastic scheduler for I/O operations + * (Default)
  • + *
  • Schedulers.parallel() - parallel scheduler for CPU-intensive + * operations
  • + *
  • Schedulers.immediate() - immediate scheduler for synchronous execution
  • + *
+ * @return This builder instance for method chaining + */ + public Builder scheduler(Scheduler scheduler) { + Assert.notNull(scheduler, "Scheduler must not be null"); + this.scheduler = scheduler; + return this; + } + + /** + * Sets the initial delay before the first keepAlive call. + * @param initialDelay The initial delay duration + * @return This builder instance for method chaining + */ + public Builder initialDelay(Duration initialDelay) { + Assert.notNull(initialDelay, "Initial delay must not be null"); + this.initialDelay = initialDelay; + return this; + } + + /** + * Sets the interval between subsequent keepAlive calls. + * @param interval The interval duration + * @return This builder instance for method chaining + */ + public Builder interval(Duration interval) { + Assert.notNull(interval, "Interval must not be null"); + this.interval = interval; + return this; + } + + /** + * Builds and returns a new KeepAliveScheduler instance. + * @return A new KeepAliveScheduler configured with the builder's settings + */ + public KeepAliveScheduler build() { + return new KeepAliveScheduler(scheduler, initialDelay, interval, mcpSessions); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java b/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java index 9644f9a6c..389727b45 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/McpUriTemplateManagerFactory.java @@ -1,6 +1,7 @@ /* * Copyright 2025 - 2025 the original author or authors. */ + package io.modelcontextprotocol.util; /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java index 8e654e596..039b0d68e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -69,6 +69,9 @@ public static boolean isEmpty(@Nullable Map map) { * base URL or URI is malformed */ public static URI resolveUri(URI baseUrl, String endpointUrl) { + if (!Utils.hasText(endpointUrl)) { + return baseUrl; + } URI endpointUri = URI.create(endpointUrl); if (endpointUri.isAbsolute() && !isUnderBaseUri(baseUrl, endpointUri)) { throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL."); diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java index 482d0aac6..b1113a6d0 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpClientTransport.java @@ -29,6 +29,8 @@ public class MockMcpClientTransport implements McpClientTransport { private final BiConsumer interceptor; + private String protocolVersion = McpSchema.LATEST_PROTOCOL_VERSION; + public MockMcpClientTransport() { this((t, msg) -> { }); @@ -38,6 +40,15 @@ public MockMcpClientTransport(BiConsumer protocolVersions() { + return List.of(protocolVersion); + } + public void simulateIncomingMessage(McpSchema.JSONRPCMessage message) { if (inbound.tryEmitNext(message).isFailure()) { throw new RuntimeException("Failed to process incoming message " + message); diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index 7ba35bbf0..e955be89f 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -1,18 +1,7 @@ /* -* Copyright 2025 - 2025 the original author or authors. -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* https://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Copyright 2025-2025 the original author or authors. + */ + package io.modelcontextprotocol; import io.modelcontextprotocol.spec.McpSchema; diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java index b673ed612..ec23e21dc 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/AbstractMcpAsyncClientResiliencyTests.java @@ -1,6 +1,7 @@ /* * Copyright 2024-2024 the original author or authors. */ + package io.modelcontextprotocol.client; import eu.rekawek.toxiproxy.Proxy; diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java index aa081b51b..aef2ab8dd 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpAsyncClientTests.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.client; import org.junit.jupiter.api.Timeout; diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java index 8285f417f..7f00de60e 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/HttpClientStreamableHttpSyncClientTests.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.client; import org.junit.jupiter.api.Timeout; diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index e773c8381..b2fd7fb65 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -52,7 +52,7 @@ private static MockMcpClientTransport initializationEnabledTransport( r.id(), mockInitResult, null); t.simulateIncomingMessage(initResponse); } - }); + }).withProtocolVersion(McpSchema.LATEST_PROTOCOL_VERSION); } @Test @@ -79,7 +79,7 @@ void testSuccessfulInitialization() { // Verify initialization result assertThat(result).isNotNull(); - assertThat(result.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION); + assertThat(result.protocolVersion()).isEqualTo(transport.protocolVersions().get(0)); assertThat(result.capabilities()).isEqualTo(serverCapabilities); assertThat(result.serverInfo()).isEqualTo(serverInfo); assertThat(result.instructions()).isEqualTo("Test instructions"); diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java index 14ca82791..ae33898b7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpAsyncClientTests.java @@ -1,9 +1,15 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.client; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; + import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; @@ -21,7 +27,7 @@ class McpAsyncClientTests { .build(); public static final McpSchema.InitializeResult MOCK_INIT_RESULT = new McpSchema.InitializeResult( - McpSchema.LATEST_PROTOCOL_VERSION, MOCK_SERVER_CAPABILITIES, MOCK_SERVER_INFO, "Test instructions"); + ProtocolVersions.MCP_2024_11_05, MOCK_SERVER_CAPABILITIES, MOCK_SERVER_INFO, "Test instructions"); private static final String CONTEXT_KEY = "context.key"; diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java index bf4738496..36216988f 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/McpClientProtocolVersionTests.java @@ -37,18 +37,20 @@ void shouldUseLatestVersionByDefault() { try { Mono initializeResultMono = client.initialize(); + String protocolVersion = transport.protocolVersions().get(transport.protocolVersions().size() - 1); + StepVerifier.create(initializeResultMono).then(() -> { McpSchema.JSONRPCRequest request = transport.getLastSentMessageAsRequest(); assertThat(request.params()).isInstanceOf(McpSchema.InitializeRequest.class); McpSchema.InitializeRequest initRequest = (McpSchema.InitializeRequest) request.params(); - assertThat(initRequest.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION); + assertThat(initRequest.protocolVersion()).isEqualTo(transport.protocolVersions().get(0)); transport.simulateIncomingMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - new McpSchema.InitializeResult(McpSchema.LATEST_PROTOCOL_VERSION, null, + new McpSchema.InitializeResult(protocolVersion, null, new McpSchema.Implementation("test-server", "1.0.0"), null), null)); }).assertNext(result -> { - assertThat(result.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION); + assertThat(result.protocolVersion()).isEqualTo(protocolVersion); }).verifyComplete(); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java index a101f0177..e9356d0c0 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -37,7 +37,7 @@ protected McpClientTransport createMcpTransport() { } protected Duration getInitializationTimeout() { - return Duration.ofSeconds(10); + return Duration.ofSeconds(20); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index 31430543a..46b9207f6 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -17,10 +17,13 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.mockito.ArgumentCaptor; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; import reactor.core.publisher.Mono; @@ -28,9 +31,17 @@ import reactor.test.StepVerifier; import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.clearInvocations; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; /** * Tests for the {@link HttpClientSseClientTransport} class. @@ -43,7 +54,7 @@ class HttpClientSseClientTransportTests { static String host = "http://localhost:3001"; @SuppressWarnings("resource") - GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") .withCommand("node dist/index.js sse") .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) .withExposedPorts(3001) @@ -61,7 +72,7 @@ static class TestHttpClientSseClientTransport extends HttpClientSseClientTranspo public TestHttpClientSseClientTransport(final String baseUri) { super(HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build(), HttpRequest.newBuilder().header("Content-Type", "application/json"), baseUri, "/sse", - new ObjectMapper()); + new ObjectMapper(), AsyncHttpRequestCustomizer.NOOP); } public int getInboundMessageCount() { @@ -80,15 +91,21 @@ public void simulateMessageEvent(String jsonMessage) { } - void startContainer() { + @BeforeAll + static void startContainer() { container.start(); int port = container.getMappedPort(3001); host = "http://" + container.getHost() + ":" + port; + + } + + @AfterAll + static void stopContainer() { + container.stop(); } @BeforeEach void setUp() { - startContainer(); transport = new TestHttpClientSseClientTransport(host); transport.connect(Function.identity()).block(); } @@ -98,11 +115,6 @@ void afterEach() { if (transport != null) { assertThatCode(() -> transport.closeGracefully().block(Duration.ofSeconds(10))).doesNotThrowAnyException(); } - cleanup(); - } - - void cleanup() { - container.stop(); } @Test @@ -375,4 +387,74 @@ void testChainedCustomizations() { customizedTransport.closeGracefully().block(); } + @Test + void testRequestCustomizer() { + var mockCustomizer = mock(SyncHttpRequestCustomizer.class); + + // Create a transport with the customizer + var customizedTransport = HttpClientSseClientTransport.builder(host) + .httpRequestCustomizer(mockCustomizer) + .build(); + + // Connect + StepVerifier.create(customizedTransport.connect(Function.identity())).verifyComplete(); + + // Verify the customizer was called + verify(mockCustomizer).customize(any(), eq("GET"), + eq(UriComponentsBuilder.fromUriString(host).path("/sse").build().toUri()), isNull()); + clearInvocations(mockCustomizer); + + // Send test message + var testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + // Subscribe to messages and verify + StepVerifier.create(customizedTransport.sendMessage(testMessage)).verifyComplete(); + + // Verify the customizer was called + var uriArgumentCaptor = ArgumentCaptor.forClass(URI.class); + verify(mockCustomizer).customize(any(), eq("POST"), uriArgumentCaptor.capture(), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"test-method\",\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}")); + assertThat(uriArgumentCaptor.getValue().toString()).startsWith(host + "/message?sessionId="); + + // Clean up + customizedTransport.closeGracefully().block(); + } + + @Test + void testAsyncRequestCustomizer() { + var mockCustomizer = mock(AsyncHttpRequestCustomizer.class); + when(mockCustomizer.customize(any(), any(), any(), any())) + .thenAnswer(invocation -> Mono.just(invocation.getArguments()[0])); + + // Create a transport with the customizer + var customizedTransport = HttpClientSseClientTransport.builder(host) + .asyncHttpRequestCustomizer(mockCustomizer) + .build(); + + // Connect + StepVerifier.create(customizedTransport.connect(Function.identity())).verifyComplete(); + + // Verify the customizer was called + verify(mockCustomizer).customize(any(), eq("GET"), + eq(UriComponentsBuilder.fromUriString(host).path("/sse").build().toUri()), isNull()); + clearInvocations(mockCustomizer); + + // Send test message + var testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id", + Map.of("key", "value")); + + // Subscribe to messages and verify + StepVerifier.create(customizedTransport.sendMessage(testMessage)).verifyComplete(); + + // Verify the customizer was called + var uriArgumentCaptor = ArgumentCaptor.forClass(URI.class); + verify(mockCustomizer).customize(any(), eq("POST"), uriArgumentCaptor.capture(), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"test-method\",\"id\":\"test-id\",\"params\":{\"key\":\"value\"}}")); + assertThat(uriArgumentCaptor.getValue().toString()).startsWith(host + "/message?sessionId="); + + // Clean up + customizedTransport.closeGracefully().block(); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java new file mode 100644 index 000000000..8b3668671 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportEmptyJsonResponseTest.java @@ -0,0 +1,92 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import com.sun.net.httpserver.HttpServer; + +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.ProtocolVersions; +import reactor.test.StepVerifier; + +/** + * Handles emplty application/json response with 200 OK status code. + * + * @author codezkk + */ +public class HttpClientStreamableHttpTransportEmptyJsonResponseTest { + + static int PORT = TomcatTestUtil.findAvailablePort(); + + static String host = "http://localhost:" + PORT; + + static HttpServer server; + + @BeforeAll + static void startContainer() throws IOException { + + server = HttpServer.create(new InetSocketAddress(PORT), 0); + + // Empty, 200 OK response for the /mcp endpoint + server.createContext("/mcp", exchange -> { + exchange.getResponseHeaders().set("Content-Type", "application/json"); + exchange.sendResponseHeaders(200, 0); + exchange.close(); + }); + + server.setExecutor(null); + server.start(); + } + + @AfterAll + static void stopContainer() { + server.stop(1); + } + + /** + * Regardless of the response (even if the response is null and the content-type is + * present), notify should handle it correctly. + */ + @Test + @Timeout(3) + void testNotificationInitialized() throws URISyntaxException { + + var uri = new URI(host + "/mcp"); + var mockRequestCustomizer = mock(SyncHttpRequestCustomizer.class); + var transport = HttpClientStreamableHttpTransport.builder(host) + .httpRequestCustomizer(mockRequestCustomizer) + .build(); + + var initializeRequest = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2025_03_26, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("Spring AI MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete(); + + // Verify the customizer was called + verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(uri), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"Spring AI MCP Client\",\"version\":\"0.3.1\"}}}")); + + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java new file mode 100644 index 000000000..d645bb0b3 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientStreamableHttpTransportTest.java @@ -0,0 +1,115 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import io.modelcontextprotocol.spec.McpSchema; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.function.Consumer; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for the {@link HttpClientStreamableHttpTransport} class. + * + * @author Daniel Garnier-Moiroux + */ +class HttpClientStreamableHttpTransportTest { + + static String host = "http://localhost:3001"; + + @SuppressWarnings("resource") + static GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v2") + .withCommand("node dist/index.js streamableHttp") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @BeforeAll + static void startContainer() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @AfterAll + static void stopContainer() { + container.stop(); + } + + void withTransport(HttpClientStreamableHttpTransport transport, Consumer c) { + try { + c.accept(transport); + } + finally { + StepVerifier.create(transport.closeGracefully()).verifyComplete(); + } + } + + @Test + void testRequestCustomizer() throws URISyntaxException { + var uri = new URI(host + "/mcp"); + var mockRequestCustomizer = mock(SyncHttpRequestCustomizer.class); + + var transport = HttpClientStreamableHttpTransport.builder(host) + .httpRequestCustomizer(mockRequestCustomizer) + .build(); + + withTransport(transport, (t) -> { + // Send test message + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("Spring AI MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(t.sendMessage(testMessage)).verifyComplete(); + + // Verify the customizer was called + verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(uri), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"Spring AI MCP Client\",\"version\":\"0.3.1\"}}}")); + }); + } + + @Test + void testAsyncRequestCustomizer() throws URISyntaxException { + var uri = new URI(host + "/mcp"); + var mockRequestCustomizer = mock(AsyncHttpRequestCustomizer.class); + when(mockRequestCustomizer.customize(any(), any(), any(), any())) + .thenAnswer(invocation -> Mono.just(invocation.getArguments()[0])); + + var transport = HttpClientStreamableHttpTransport.builder(host) + .asyncHttpRequestCustomizer(mockRequestCustomizer) + .build(); + + withTransport(transport, (t) -> { + // Send test message + var initializeRequest = new McpSchema.InitializeRequest(McpSchema.LATEST_PROTOCOL_VERSION, + McpSchema.ClientCapabilities.builder().roots(true).build(), + new McpSchema.Implementation("Spring AI MCP Client", "0.3.1")); + var testMessage = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, McpSchema.METHOD_INITIALIZE, + "test-id", initializeRequest); + + StepVerifier.create(t.sendMessage(testMessage)).verifyComplete(); + + // Verify the customizer was called + verify(mockRequestCustomizer, atLeastOnce()).customize(any(), eq("POST"), eq(uri), eq( + "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"id\":\"test-id\",\"params\":{\"protocolVersion\":\"2025-03-26\",\"capabilities\":{\"roots\":{\"listChanged\":true}},\"clientInfo\":{\"name\":\"Spring AI MCP Client\",\"version\":\"0.3.1\"}}}")); + }); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index b5841e755..0ba8bf929 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -4,9 +4,17 @@ package io.modelcontextprotocol.server; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + import java.time.Duration; import java.util.List; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -18,16 +26,9 @@ import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerTransportProvider; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - /** * Test suite for the {@link McpAsyncServer} that can be used with different * {@link io.modelcontextprotocol.spec.McpServerTransportProvider} implementations. @@ -43,7 +44,7 @@ public abstract class AbstractMcpAsyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected McpServerTransportProvider createMcpTransportProvider(); + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); protected void onStart() { } @@ -63,29 +64,27 @@ void tearDown() { // --------------------------------------- // Server Lifecycle Tests // --------------------------------------- - - @Test void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport provider must not be null"); - assertThatThrownBy( - () -> McpServer.async(createMcpTransportProvider()).serverInfo((McpSchema.Implementation) null)) + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo((McpSchema.Implementation) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + McpServer.AsyncSpecification builder = prepareAsyncServerBuilder(); + var mcpAsyncServer = builder.serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); } @Test void testImmediateClose() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); } @@ -105,8 +104,7 @@ void testImmediateClose() { @Deprecated void testAddTool() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -120,8 +118,7 @@ void testAddTool() { @Test void testAddToolCall() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -138,8 +135,7 @@ void testAddToolCall() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -159,8 +155,7 @@ void testAddDuplicateTool() { void testAddDuplicateToolCall() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -181,8 +176,7 @@ void testDuplicateToolCallDuringBuilding() { Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) // Duplicate! @@ -204,8 +198,7 @@ void testDuplicateToolsInBatchListRegistration() { .build() // Duplicate! ); - assertThatThrownBy(() -> McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(specs) .build()).isInstanceOf(IllegalArgumentException.class) @@ -216,8 +209,7 @@ void testDuplicateToolsInBatchListRegistration() { void testDuplicateToolsInBatchVarargsRegistration() { Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) @@ -236,8 +228,7 @@ void testDuplicateToolsInBatchVarargsRegistration() { void testRemoveTool() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -249,8 +240,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -265,8 +255,7 @@ void testRemoveNonexistentTool() { void testNotifyToolsListChanged() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -282,7 +271,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); @@ -291,7 +280,7 @@ void testNotifyResourcesListChanged() { @Test void testNotifyResourcesUpdated() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier .create(mcpAsyncServer @@ -303,8 +292,7 @@ void testNotifyResourcesUpdated() { @Test void testAddResource() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -320,8 +308,7 @@ void testAddResource() { @Test void testAddResourceWithNullSpecification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -336,9 +323,7 @@ void testAddResourceWithNullSpecification() { @Test void testAddResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); @@ -354,9 +339,7 @@ void testAddResourceWithoutCapability() { @Test void testRemoveResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) @@ -370,7 +353,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); @@ -379,8 +362,7 @@ void testNotifyPromptsListChanged() { @Test void testAddPromptWithNullSpecification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); @@ -393,9 +375,7 @@ void testAddPromptWithNullSpecification() { @Test void testAddPromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( @@ -411,9 +391,7 @@ void testAddPromptWithoutCapability() { @Test void testRemovePromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) @@ -430,8 +408,7 @@ void testRemovePrompt() { prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .prompts(specification) .build(); @@ -443,8 +420,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpAsyncServer2 = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer2 = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -467,8 +443,7 @@ void testRootsChangeHandlers() { var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var singleConsumerServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumerCalled[0] = true; if (!roots.isEmpty()) { @@ -487,8 +462,7 @@ void testRootsChangeHandlers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var multipleConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumer1Called[0] = true; rootsContent[0] = roots; @@ -501,8 +475,7 @@ void testRootsChangeHandlers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var errorHandlingServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) @@ -514,9 +487,7 @@ void testRootsChangeHandlers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var noConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java new file mode 100644 index 000000000..cc580bdd4 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpClientServerIntegrationTests.java @@ -0,0 +1,1596 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertWith; +import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +import jakarta.servlet.http.HttpServletRequest; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import net.javacrumbs.jsonunit.core.Option; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +public abstract class AbstractMcpClientServerIntegrationTests { + + protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + abstract protected void prepareClients(int port, String mcpEndpoint); + + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); + + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void simple(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1000)) + .build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .requestTimeout(Duration.ofSeconds(1000)) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + } + server.closeGracefully(); + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testCreateMessageWithoutSamplingCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); + return Mono.just(mock(CallToolResult.class)); + }) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } + } + server.closeGracefully(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testCreateMessageSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + AtomicReference samplingResult = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { + + // Client + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + AtomicReference samplingResult = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(4)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest).thenReturn(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("1000ms"); + + mcpClient.close(); + mcpServer.close(); + } + + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testCreateElicitationWithoutElicitationCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> exchange.createElicitation(mock(ElicitRequest.class)) + .then(Mono.just(mock(CallToolResult.class)))) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + // Create client without elicitation capabilities + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testCreateElicitationSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + AtomicReference resultRef = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + return exchange.createElicitation(elicitationRequest) + .doOnNext(resultRef::set) + .then(Mono.just(callResponse)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + assertWith(resultRef.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testCreateElicitationWithRequestTimeoutFail(String clientType) { + + var latch = new CountDownLatch(1); + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + try { + if (!latch.await(2, TimeUnit.SECONDS)) { + throw new RuntimeException("Timeout waiting for elicitation processing"); + } + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + AtomicReference resultRef = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + return exchange.createElicitation(elicitationRequest) + .doOnNext(resultRef::set) + .then(Mono.just(callResponse)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) // 1 second. + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + ElicitResult elicitResult = resultRef.get(); + assertThat(elicitResult).isNull(); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testRootsSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testRootsWithoutCapability(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); + + try ( + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testRootsNotificationWithEmptyRootsList(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testRootsWithMultipleHandlers(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testRootsServerCloseWithActiveSubscription(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + + return callResponse; + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull().isEqualTo(callResponse); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpSyncServer mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool1") + .description("tool1 description") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + // We trigger a timeout on blocking read, raising an exception + Mono.never().block(Duration.ofSeconds(1)); + return null; + }) + .build()) + .build(); + + try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // We expect the tool call to fail immediately with the exception raised by + // the offending tool + // instead of getting back a timeout. + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) + .withMessageContaining("Timeout on blocking read"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testToolCallSuccessWithTranportContextExtraction(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var expectedCallResponse = new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("CALL RESPONSE; ctx=value")), null); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + McpTransportContext transportContext = exchange.transportContext(); + assertTrue(transportContext != null, "transportContext should not be null"); + assertTrue(!transportContext.equals(McpTransportContext.EMPTY), "transportContext should not be empty"); + String ctxValue = (String) transportContext.get("important"); + + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("CALL RESPONSE; ctx=" + ctxValue)), null); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull().isEqualTo(expectedCallResponse); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + return callResponse; + }) + .build(); + + AtomicReference> toolsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + toolsRef.set(toolsUpdate); + } + catch (Exception e) { + e.printStackTrace(); + } + }).build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(toolsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool2") + .description("tool2 description") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> callResponse) + .build(); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(toolsRef.get()).containsAll(List.of(tool2.tool())); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = prepareSyncServerBuilder().build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testLoggingNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 3; + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + ; + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("logging-test") + .description("Test logging notifications") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications with different levels + + //@formatter:off + return exchange // This should be filtered out (DEBUG < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message") + .build()) + .then(exchange // This should be sent (NOTICE >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.NOTICE) + .logger("test-logger") + .data("Notice message") + .build())) + .then(exchange // This should be sent (ERROR > NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Error message") + .build())) + .then(exchange // This should be filtered out (INFO < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Another info message") + .build())) + .then(exchange // This should be sent (ERROR >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Another error message") + .build())) + .thenReturn(new CallToolResult("Logging test completed", false)); + //@formatter:on + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with logging notification handler + var mcpClient = clientBuilder.loggingConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Set minimum logging level to NOTICE + mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); + + // Call the tool that sends logging notifications + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); + + // First notification should be NOTICE level + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); + + // Second notification should be ERROR level + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); + + // Third notification should be ERROR level + assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); + } + mcpServer.close(); + } + + // --------------------------------------- + // Progress Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testProgressNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress + // token + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(McpSchema.Tool.builder() + .name("progress-test") + .description("Test progress notifications") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications + var progressToken = (String) request.meta().get("progressToken"); + + return exchange + .progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.0, 1.0, "Processing started")) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Processing data"))) + .then(// Send a progress notification with another progress value + // should + exchange.progressNotification(new McpSchema.ProgressNotification("another-progress-token", + 0.0, 1.0, "Another processing started"))) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) + .thenReturn(new CallToolResult(("Progress test completed"), false)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with progress notification handler + var mcpClient = clientBuilder.progressConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that sends progress notifications + McpSchema.CallToolRequest callToolRequest = McpSchema.CallToolRequest.builder() + .name("progress-test") + .meta(Map.of("progressToken", "test-progress-token")) + .build(); + CallToolResult result = mcpClient.callTool(callToolRequest); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.message(), n -> n)); + + // First notification should be 0.0/1.0 progress + assertThat(notificationMap.get("Processing started").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing started").message()).isEqualTo("Processing started"); + + // Second notification should be 0.5/1.0 progress + assertThat(notificationMap.get("Processing data").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing data").progress()).isEqualTo(0.5); + assertThat(notificationMap.get("Processing data").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing data").message()).isEqualTo("Processing data"); + + // Third notification should be another progress token with 0.0/1.0 progress + assertThat(notificationMap.get("Another processing started").progressToken()) + .isEqualTo("another-progress-token"); + assertThat(notificationMap.get("Another processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Another processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Another processing started").message()) + .isEqualTo("Another processing started"); + + // Fourth notification should be 1.0/1.0 progress + assertThat(notificationMap.get("Processing completed").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing completed").progress()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); + } + finally { + mcpServer.close(); + } + } + + // --------------------------------------- + // Completion Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : Completion call") + @ValueSource(strings = { "httpclient" }) + void testCompletionShouldReturnExpectedSuggestions(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + var expectedValues = List.of("python", "pytorch", "pyside"); + var completionResponse = new McpSchema.CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + + AtomicReference samplingRequest = new AtomicReference<>(); + BiFunction completionHandler = (mcpSyncServerExchange, + request) -> { + samplingRequest.set(request); + return completionResponse; + }; + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpServerFeatures.SyncPromptSpecification( + new Prompt("code_review", "Code review", "this is code review prompt", + List.of(new PromptArgument("language", "Language", "string", false))), + (mcpSyncServerExchange, getPromptRequest) -> null)) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new McpSchema.PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CompleteRequest request = new CompleteRequest( + new PromptReference("ref/prompt", "code_review", "Code review"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult result = mcpClient.completeCompletion(request); + + assertThat(result).isNotNull(); + + assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); + assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); + assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Ping Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testPingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that uses ping functionality + AtomicReference executionOrder = new AtomicReference<>(""); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("ping-async-test") + .description("Test ping async behavior") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + + executionOrder.set(executionOrder.get() + "1"); + + // Test async ping behavior + return exchange.ping().doOnNext(result -> { + + assertThat(result).isNotNull(); + // Ping should return an empty object or map + assertThat(result).isInstanceOf(Map.class); + + executionOrder.set(executionOrder.get() + "2"); + assertThat(result).isNotNull(); + }).then(Mono.fromCallable(() -> { + executionOrder.set(executionOrder.get() + "3"); + return new CallToolResult("Async ping test completed", false); + })); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that tests ping async behavior + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); + + // Verify execution order + assertThat(executionOrder.get()).isEqualTo("123"); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + // In WebMVC, structured content is returned properly + if (response.structuredContent() != null) { + assertThat(response.structuredContent()).containsEntry("result", 5.0) + .containsEntry("operation", "2 + 3") + .containsEntry("timestamp", "2024-01-01T10:00:00Z"); + } + else { + // Fallback to checking content if structured content is not available + assertThat(response.content()).isNotEmpty(); + } + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputValidationFailure(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputMissingStructuredContent(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputRuntimeToolAddition(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification toolSpec = McpServerFeatures.SyncToolSpecification.builder() + .tool(dynamicTool) + .callHandler((exchange, request) -> { + int count = (Integer) request.arguments().getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }) + .build(); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + + mcpServer.close(); + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + + protected static McpTransportContextExtractor extractor = (r, tc) -> { + tc.put("important", "value"); + return tc; + }; + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 208d2e749..67579ce72 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -28,7 +28,7 @@ /** * Test suite for the {@link McpSyncServer} that can be used with different - * {@link McpTransportProvider} implementations. + * {@link McpServerTransportProvider} implementations. * * @author Christian Tzolov */ @@ -41,7 +41,7 @@ public abstract class AbstractMcpSyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected McpServerTransportProvider createMcpTransportProvider(); + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); protected void onStart() { } @@ -69,28 +69,28 @@ void testConstructorWithInvalidArguments() { .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport provider must not be null"); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()).serverInfo(null)) + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo(null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); } @Test void testImmediateClose() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); } @Test void testGetAsyncServer() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); @@ -112,8 +112,7 @@ void testGetAsyncServer() { @Test @Deprecated void testAddTool() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -127,8 +126,7 @@ void testAddTool() { @Test void testAddToolCall() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -146,8 +144,7 @@ void testAddToolCall() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) .build(); @@ -164,8 +161,7 @@ void testAddDuplicateTool() { void testAddDuplicateToolCall() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) .build(); @@ -184,8 +180,7 @@ void testDuplicateToolCallDuringBuilding() { Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) // Duplicate! @@ -207,8 +202,7 @@ void testDuplicateToolsInBatchListRegistration() { .build() // Duplicate! ); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(specs) .build()).isInstanceOf(IllegalArgumentException.class) @@ -219,8 +213,7 @@ void testDuplicateToolsInBatchListRegistration() { void testDuplicateToolsInBatchVarargsRegistration() { Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) @@ -239,8 +232,7 @@ void testDuplicateToolsInBatchVarargsRegistration() { void testRemoveTool() { Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(tool, (exchange, args) -> new CallToolResult(List.of(), false)) .build(); @@ -252,8 +244,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -265,7 +256,7 @@ void testRemoveNonexistentTool() { @Test void testNotifyToolsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); @@ -278,7 +269,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); @@ -287,7 +278,7 @@ void testNotifyResourcesListChanged() { @Test void testNotifyResourcesUpdated() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) @@ -298,8 +289,7 @@ void testNotifyResourcesUpdated() { @Test void testAddResource() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -315,8 +305,7 @@ void testAddResource() { @Test void testAddResourceWithNullSpecification() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -329,9 +318,7 @@ void testAddResourceWithNullSpecification() { @Test void testAddResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); @@ -344,9 +331,7 @@ void testAddResourceWithoutCapability() { @Test void testRemoveResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with resource capabilities"); @@ -358,7 +343,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); @@ -367,8 +352,7 @@ void testNotifyPromptsListChanged() { @Test void testAddPromptWithNullSpecification() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); @@ -379,9 +363,7 @@ void testAddPromptWithNullSpecification() { @Test void testAddPromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, @@ -394,9 +376,7 @@ void testAddPromptWithoutCapability() { @Test void testRemovePromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); @@ -409,8 +389,7 @@ void testRemovePrompt() { (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .prompts(specification) .build(); @@ -422,8 +401,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -443,8 +421,7 @@ void testRootsChangeHandlers() { var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var singleConsumerServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { consumerCalled[0] = true; if (!roots.isEmpty()) { @@ -462,8 +439,7 @@ void testRootsChangeHandlers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var multipleConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { consumer1Called[0] = true; rootsContent[0] = roots; @@ -475,8 +451,7 @@ void testRootsChangeHandlers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var errorHandlingServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) @@ -487,7 +462,7 @@ void testRootsChangeHandlers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var noConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java new file mode 100644 index 000000000..4435b8b44 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletSseIntegrationTests.java @@ -0,0 +1,94 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; + +class HttpServletSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private HttpServletSseServerTransportProvider mcpServerTransportProvider; + + private Tomcat tomcat; + + @BeforeEach + public void before() { + // Create and configure the transport provider + mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(extractor) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_SSE_ENDPOINT) + .build()).requestTimeout(Duration.ofHours(10))); + } + + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransportProvider); + } + + @Override + protected SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransportProvider); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java new file mode 100644 index 000000000..4c3f22d76 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java @@ -0,0 +1,523 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.ProtocolVersions; +import net.javacrumbs.jsonunit.core.Option; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.web.client.RestClient; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; + +import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.APPLICATION_JSON; +import static io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport.TEXT_EVENT_STREAM; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +class HttpServletStatelessIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private HttpServletStatelessServerTransport mcpStatelessServerTransport; + + ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + private Tomcat tomcat; + + @BeforeEach + public void before() { + this.mcpStatelessServerTransport = HttpServletStatelessServerTransport.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpStatelessServerTransport); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); + } + + @AfterEach + public void after() { + if (mcpStatelessServerTransport != null) { + mcpStatelessServerTransport.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpStatelessServerFeatures.SyncToolSpecification tool1 = new McpStatelessServerFeatures.SyncToolSpecification( + new Tool("tool1", "tool1 description", emptyJsonSchema), (transportContext, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport).build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Completion Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : Completion call") + @ValueSource(strings = { "httpclient" }) + void testCompletionShouldReturnExpectedSuggestions(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + var expectedValues = List.of("python", "pytorch", "pyside"); + var completionResponse = new CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + + AtomicReference samplingRequest = new AtomicReference<>(); + BiFunction completionHandler = (transportContext, + request) -> { + samplingRequest.set(request); + return completionResponse; + }; + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpStatelessServerFeatures.SyncPromptSpecification( + new Prompt("code_review", "Code review", "this is code review prompt", + List.of(new PromptArgument("language", "Language", "string", false))), + (transportContext, getPromptRequest) -> null)) + .completions(new McpStatelessServerFeatures.SyncCompletionSpecification( + new PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CompleteRequest request = new CompleteRequest( + new PromptReference("ref/prompt", "code_review", "Code review"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult result = mcpClient.completeCompletion(request); + + assertThat(result).isNotNull(); + + assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); + assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); + assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = new McpStatelessServerFeatures.SyncToolSpecification( + calculatorTool, (transportContext, request) -> { + String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + assertThatJson(((McpSchema.TextContent) response.content().get(0)).text()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputValidationFailure(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = new McpStatelessServerFeatures.SyncToolSpecification( + calculatorTool, (transportContext, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputMissingStructuredContent(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = new McpStatelessServerFeatures.SyncToolSpecification( + calculatorTool, (transportContext, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .instructions("bla") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputRuntimeToolAddition(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification toolSpec = new McpStatelessServerFeatures.SyncToolSpecification( + dynamicTool, (transportContext, request) -> { + int count = (Integer) request.arguments().getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + + mcpServer.close(); + } + + @Test + void testThrownMcpError() throws Exception { + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + Tool testTool = Tool.builder().name("test").description("test").build(); + + McpStatelessServerFeatures.SyncToolSpecification toolSpec = new McpStatelessServerFeatures.SyncToolSpecification( + testTool, (transportContext, request) -> { + throw new McpError(new McpSchema.JSONRPCResponse.JSONRPCError(12345, "testing", Map.of("a", "b"))); + }); + + mcpServer.addTool(toolSpec); + + McpSchema.CallToolRequest callToolRequest = new McpSchema.CallToolRequest("test", Map.of()); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_TOOLS_CALL, "test", callToolRequest); + + MockHttpServletRequest request = new MockHttpServletRequest("POST", CUSTOM_MESSAGE_ENDPOINT); + MockHttpServletResponse response = new MockHttpServletResponse(); + + byte[] content = new ObjectMapper().writeValueAsBytes(jsonrpcRequest); + request.setContent(content); + request.addHeader("Content-Type", "application/json"); + request.addHeader("Content-Length", Integer.toString(content.length)); + request.addHeader("Content-Length", Integer.toString(content.length)); + request.addHeader("Accept", APPLICATION_JSON + ", " + TEXT_EVENT_STREAM); + request.addHeader("Content-Type", APPLICATION_JSON); + request.addHeader("Cache-Control", "no-cache"); + request.addHeader(HttpHeaders.PROTOCOL_VERSION, ProtocolVersions.MCP_2025_03_26); + mcpStatelessServerTransport.service(request, response); + + McpSchema.JSONRPCResponse jsonrpcResponse = new ObjectMapper().readValue(response.getContentAsByteArray(), + McpSchema.JSONRPCResponse.class); + + assertThat(jsonrpcResponse.error()) + .isEqualTo(new McpSchema.JSONRPCResponse.JSONRPCError(12345, "testing", Map.of("a", "b"))); + + mcpServer.close(); + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java new file mode 100644 index 000000000..327ec1b21 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableAsyncServerTests.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import org.junit.jupiter.api.Timeout; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; + +/** + * Tests for {@link McpAsyncServer} using + * {@link HttpServletStreamableServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) +class HttpServletStreamableAsyncServerTests extends AbstractMcpAsyncServerTests { + + protected McpStreamableServerTransportProvider createMcpTransportProvider() { + return HttpServletStreamableServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .mcpEndpoint("/mcp/message") + .build(); + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java new file mode 100644 index 000000000..0815556b9 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableIntegrationTests.java @@ -0,0 +1,92 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; + +class HttpServletStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private HttpServletStreamableServerTransportProvider mcpServerTransportProvider; + + private Tomcat tomcat; + + @BeforeEach + public void before() { + // Create and configure the transport provider + mcpServerTransportProvider = HttpServletStreamableServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(extractor) + .mcpEndpoint(MESSAGE_ENDPOINT) + .keepAliveInterval(Duration.ofSeconds(1)) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(MESSAGE_ENDPOINT) + .build()).requestTimeout(Duration.ofHours(10))); + } + + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransportProvider); + } + + @Override + protected SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransportProvider); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java new file mode 100644 index 000000000..66fa2b2ac --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStreamableSyncServerTests.java @@ -0,0 +1,35 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import org.junit.jupiter.api.Timeout; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; + +/** + * Tests for {@link McpSyncServer} using + * {@link HttpServletStreamableServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) +class HttpServletStreamableSyncServerTests extends AbstractMcpSyncServerTests { + + protected McpStreamableServerTransportProvider createMcpTransportProvider() { + return HttpServletStreamableServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .mcpEndpoint("/mcp/message") + .build(); + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java index 39066a9a2..987c43663 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java @@ -25,7 +25,6 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -54,7 +53,8 @@ void setUp() { clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); - exchange = new McpAsyncServerExchange(mockSession, clientCapabilities, clientInfo); + exchange = new McpAsyncServerExchange("testSessionId", mockSession, clientCapabilities, clientInfo, + new DefaultMcpTransportContext()); } @Test @@ -219,27 +219,33 @@ void testLoggingNotificationWithNullMessage() { } @Test - void testLoggingNotificationWithAllowedLevel() { + void testSetMinLoggingLevelWithNullValue() { + assertThatThrownBy(() -> exchange.setMinLoggingLevel(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("minLoggingLevel must not be null"); + } + @Test + void testLoggingNotificationWithAllowedLevel() { McpSchema.LoggingMessageNotification notification = McpSchema.LoggingMessageNotification.builder() .level(McpSchema.LoggingLevel.ERROR) .logger("test-logger") .data("Test error message") .build(); + when(mockSession.isNotificationForLevelAllowed(any())).thenReturn(Boolean.TRUE); when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) .thenReturn(Mono.empty()); StepVerifier.create(exchange.loggingNotification(notification)).verifyComplete(); - // Verify that sendNotification was called exactly once + verify(mockSession, times(1)).isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.ERROR)); verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification)); } @Test void testLoggingNotificationWithFilteredLevel() { - // Given - Set minimum level to WARNING, send DEBUG message - exchange.setMinLoggingLevel(McpSchema.LoggingLevel.WARNING); + exchange.setMinLoggingLevel(McpSchema.LoggingLevel.DEBUG); + verify(mockSession, times(1)).setMinLoggingLevel(eq(McpSchema.LoggingLevel.DEBUG)); McpSchema.LoggingMessageNotification debugNotification = McpSchema.LoggingMessageNotification.builder() .level(McpSchema.LoggingLevel.DEBUG) @@ -247,104 +253,38 @@ void testLoggingNotificationWithFilteredLevel() { .data("Debug message that should be filtered") .build(); - // When & Then - Should complete without sending notification - StepVerifier.create(exchange.loggingNotification(debugNotification)).verifyComplete(); - - // Verify that sendNotification was never called for filtered DEBUG level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification)); - } - - @Test - void testLoggingNotificationLevelFiltering() { - // Given - Set minimum level to WARNING - exchange.setMinLoggingLevel(McpSchema.LoggingLevel.WARNING); - - // Test DEBUG (should be filtered) - McpSchema.LoggingMessageNotification debugNotification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.DEBUG) - .logger("test-logger") - .data("Debug message") - .build(); + when(mockSession.isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.DEBUG))).thenReturn(Boolean.TRUE); + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification))) + .thenReturn(Mono.empty()); StepVerifier.create(exchange.loggingNotification(debugNotification)).verifyComplete(); - // Verify that sendNotification was never called for DEBUG level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification)); - - // Test INFO (should be filtered) - McpSchema.LoggingMessageNotification infoNotification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Info message") - .build(); - - StepVerifier.create(exchange.loggingNotification(infoNotification)).verifyComplete(); - - // Verify that sendNotification was never called for INFO level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification)); - - reset(mockSession); + verify(mockSession, times(1)).isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.DEBUG)); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), + eq(debugNotification)); - // Test WARNING (should be sent) McpSchema.LoggingMessageNotification warningNotification = McpSchema.LoggingMessageNotification.builder() .level(McpSchema.LoggingLevel.WARNING) .logger("test-logger") - .data("Warning message") + .data("Debug message that should be filtered") .build(); - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(warningNotification))) - .thenReturn(Mono.empty()); - StepVerifier.create(exchange.loggingNotification(warningNotification)).verifyComplete(); - // Verify that sendNotification was called exactly once for WARNING level - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), + verify(mockSession, times(1)).isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.WARNING)); + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(warningNotification)); - - // Test ERROR (should be sent) - McpSchema.LoggingMessageNotification errorNotification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Error message") - .build(); - - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(errorNotification))) - .thenReturn(Mono.empty()); - - StepVerifier.create(exchange.loggingNotification(errorNotification)).verifyComplete(); - - // Verify that sendNotification was called exactly once for ERROR level - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), - eq(errorNotification)); - } - - @Test - void testLoggingNotificationWithDefaultLevel() { - - McpSchema.LoggingMessageNotification infoNotification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Info message") - .build(); - - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification))) - .thenReturn(Mono.empty()); - - StepVerifier.create(exchange.loggingNotification(infoNotification)).verifyComplete(); - - // Verify that sendNotification was called exactly once for default level - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification)); } @Test void testLoggingNotificationWithSessionError() { - McpSchema.LoggingMessageNotification notification = McpSchema.LoggingMessageNotification.builder() .level(McpSchema.LoggingLevel.ERROR) .logger("test-logger") .data("Test error message") .build(); + when(mockSession.isNotificationForLevelAllowed(any())).thenReturn(Boolean.TRUE); when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) .thenReturn(Mono.error(new RuntimeException("Session error"))); @@ -353,44 +293,6 @@ void testLoggingNotificationWithSessionError() { }); } - @Test - void testSetMinLoggingLevelWithNullValue() { - // When & Then - assertThatThrownBy(() -> exchange.setMinLoggingLevel(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("minLoggingLevel must not be null"); - } - - @Test - void testLoggingLevelHierarchy() { - // Test all logging levels to ensure proper hierarchy - McpSchema.LoggingLevel[] levels = { McpSchema.LoggingLevel.DEBUG, McpSchema.LoggingLevel.INFO, - McpSchema.LoggingLevel.NOTICE, McpSchema.LoggingLevel.WARNING, McpSchema.LoggingLevel.ERROR, - McpSchema.LoggingLevel.CRITICAL, McpSchema.LoggingLevel.ALERT, McpSchema.LoggingLevel.EMERGENCY }; - - // Set minimum level to WARNING - exchange.setMinLoggingLevel(McpSchema.LoggingLevel.WARNING); - - for (McpSchema.LoggingLevel level : levels) { - McpSchema.LoggingMessageNotification notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message for " + level) - .build(); - - if (level.level() >= McpSchema.LoggingLevel.WARNING.level()) { - // Should be sent - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) - .thenReturn(Mono.empty()); - - StepVerifier.create(exchange.loggingNotification(notification)).verifyComplete(); - } - else { - // Should be filtered (completes without sending) - StepVerifier.create(exchange.loggingNotification(notification)).verifyComplete(); - } - } - } - // --------------------------------------- // Create Elicitation Tests // --------------------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java index 26b75946b..e329188f9 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.server; import java.util.List; @@ -41,6 +45,8 @@ class McpCompletionTests { private HttpServletSseServerTransportProvider mcpServerTransportProvider; + private static final int PORT = TomcatTestUtil.findAvailablePort(); + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; McpClient.SyncSpec clientBuilder; @@ -55,7 +61,7 @@ public void before() { .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) .build(); - tomcat = TomcatTestUtil.createTomcatServer("", 3400, mcpServerTransportProvider); + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); try { tomcat.start(); assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); @@ -64,7 +70,7 @@ public void before() { throw new RuntimeException("Failed to start Tomcat", e); } - this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + 3400).build()); + this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).build()); } @AfterEach diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java index f643f1ba3..cdd2bacb7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/McpServerProtocolVersionTests.java @@ -45,7 +45,9 @@ void shouldUseLatestVersionByDefault() { assertThat(jsonResponse.id()).isEqualTo(requestId); assertThat(jsonResponse.result()).isInstanceOf(McpSchema.InitializeResult.class); McpSchema.InitializeResult result = (McpSchema.InitializeResult) jsonResponse.result(); - assertThat(result.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION); + + var protocolVersion = transportProvider.protocolVersions().get(transportProvider.protocolVersions().size() - 1); + assertThat(result.protocolVersion()).isEqualTo(protocolVersion); server.closeGracefully().subscribe(); } @@ -93,7 +95,8 @@ void shouldSuggestLatestVersionForUnsupportedVersion() { assertThat(jsonResponse.id()).isEqualTo(requestId); assertThat(jsonResponse.result()).isInstanceOf(McpSchema.InitializeResult.class); McpSchema.InitializeResult result = (McpSchema.InitializeResult) jsonResponse.result(); - assertThat(result.protocolVersion()).isEqualTo(McpSchema.LATEST_PROTOCOL_VERSION); + var protocolVersion = transportProvider.protocolVersions().get(transportProvider.protocolVersions().size() - 1); + assertThat(result.protocolVersion()).isEqualTo(protocolVersion); server.closeGracefully().subscribe(); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java index 66d7695e8..a73ec7209 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java @@ -24,7 +24,6 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -226,19 +225,21 @@ void testLoggingNotificationWithAllowedLevel() { .data("Test error message") .build(); + when(mockSession.isNotificationForLevelAllowed(any())).thenReturn(Boolean.TRUE); when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) .thenReturn(Mono.empty()); exchange.loggingNotification(notification); // Verify that sendNotification was called exactly once + verify(mockSession, times(1)).isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.ERROR)); verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification)); } @Test void testLoggingNotificationWithFilteredLevel() { - // Given - Set minimum level to WARNING, send DEBUG message - asyncExchange.setMinLoggingLevel(McpSchema.LoggingLevel.WARNING); + asyncExchange.setMinLoggingLevel(McpSchema.LoggingLevel.DEBUG); + verify(mockSession, times(1)).setMinLoggingLevel(McpSchema.LoggingLevel.DEBUG); McpSchema.LoggingMessageNotification debugNotification = McpSchema.LoggingMessageNotification.builder() .level(McpSchema.LoggingLevel.DEBUG) @@ -246,93 +247,27 @@ void testLoggingNotificationWithFilteredLevel() { .data("Debug message that should be filtered") .build(); - // When & Then - Should complete without sending notification - exchange.loggingNotification(debugNotification); - - // Verify that sendNotification was never called for filtered DEBUG level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification)); - } - - @Test - void testLoggingNotificationLevelFiltering() { - // Given - Set minimum level to WARNING - asyncExchange.setMinLoggingLevel(McpSchema.LoggingLevel.WARNING); - - // Test DEBUG (should be filtered) - McpSchema.LoggingMessageNotification debugNotification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.DEBUG) - .logger("test-logger") - .data("Debug message") - .build(); + when(mockSession.isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.DEBUG))).thenReturn(Boolean.TRUE); + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification))) + .thenReturn(Mono.empty()); exchange.loggingNotification(debugNotification); - // Verify that sendNotification was never called for DEBUG level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification)); - - // Test INFO (should be filtered) - McpSchema.LoggingMessageNotification infoNotification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Info message") - .build(); - - exchange.loggingNotification(infoNotification); - - // Verify that sendNotification was never called for INFO level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification)); - - reset(mockSession); + verify(mockSession, times(1)).isNotificationForLevelAllowed(McpSchema.LoggingLevel.DEBUG); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), + eq(debugNotification)); - // Test WARNING (should be sent) McpSchema.LoggingMessageNotification warningNotification = McpSchema.LoggingMessageNotification.builder() .level(McpSchema.LoggingLevel.WARNING) .logger("test-logger") - .data("Warning message") + .data("Debug message that should be filtered") .build(); - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(warningNotification))) - .thenReturn(Mono.empty()); - exchange.loggingNotification(warningNotification); - // Verify that sendNotification was called exactly once for WARNING level - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), + verify(mockSession, times(1)).isNotificationForLevelAllowed(McpSchema.LoggingLevel.WARNING); + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(warningNotification)); - - // Test ERROR (should be sent) - McpSchema.LoggingMessageNotification errorNotification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Error message") - .build(); - - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(errorNotification))) - .thenReturn(Mono.empty()); - - exchange.loggingNotification(errorNotification); - - // Verify that sendNotification was called exactly once for ERROR level - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), - eq(errorNotification)); - } - - @Test - void testLoggingNotificationWithDefaultLevel() { - - McpSchema.LoggingMessageNotification infoNotification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Info message") - .build(); - - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification))) - .thenReturn(Mono.empty()); - - exchange.loggingNotification(infoNotification); - - // Verify that sendNotification was called exactly once for default level - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification)); } @Test @@ -344,6 +279,7 @@ void testLoggingNotificationWithSessionError() { .data("Test error message") .build(); + when(mockSession.isNotificationForLevelAllowed(any())).thenReturn(Boolean.TRUE); when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) .thenReturn(Mono.error(new RuntimeException("Session error"))); @@ -351,37 +287,6 @@ void testLoggingNotificationWithSessionError() { .hasMessage("Session error"); } - @Test - void testLoggingLevelHierarchy() { - // Test all logging levels to ensure proper hierarchy - McpSchema.LoggingLevel[] levels = { McpSchema.LoggingLevel.DEBUG, McpSchema.LoggingLevel.INFO, - McpSchema.LoggingLevel.NOTICE, McpSchema.LoggingLevel.WARNING, McpSchema.LoggingLevel.ERROR, - McpSchema.LoggingLevel.CRITICAL, McpSchema.LoggingLevel.ALERT, McpSchema.LoggingLevel.EMERGENCY }; - - // Set minimum level to WARNING - asyncExchange.setMinLoggingLevel(McpSchema.LoggingLevel.WARNING); - - for (McpSchema.LoggingLevel level : levels) { - McpSchema.LoggingMessageNotification notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message for " + level) - .build(); - - if (level.level() >= McpSchema.LoggingLevel.WARNING.level()) { - // Should be sent - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) - .thenReturn(Mono.empty()); - - exchange.loggingNotification(notification); - } - else { - // Should be filtered (completes without sending) - exchange.loggingNotification(notification); - } - } - } - // --------------------------------------- // Create Elicitation Tests // --------------------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java index 81d904292..8906adfe0 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java @@ -16,9 +16,13 @@ @Timeout(15) // Giving extra time beyond the client timeout class ServletSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { - @Override protected McpServerTransportProvider createMcpTransportProvider() { return HttpServletSseServerTransportProvider.builder().messageEndpoint("/mcp/message").build(); } + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java index 154cf3a61..7b77f9241 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java @@ -16,9 +16,13 @@ @Timeout(15) // Giving extra time beyond the client timeout class ServletSseMcpSyncServerTests extends AbstractMcpSyncServerTests { - @Override protected McpServerTransportProvider createMcpTransportProvider() { return HttpServletSseServerTransportProvider.builder().messageEndpoint("/mcp/message").build(); } + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java index 0381a43bd..97db5fa06 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java @@ -16,9 +16,13 @@ @Timeout(15) // Giving extra time beyond the client timeout class StdioMcpAsyncServerTests extends AbstractMcpAsyncServerTests { - @Override protected McpServerTransportProvider createMcpTransportProvider() { return new StdioServerTransportProvider(); } + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java index a71c38493..1e01962e9 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java @@ -16,9 +16,13 @@ @Timeout(15) // Giving extra time beyond the client timeout class StdioMcpSyncServerTests extends AbstractMcpSyncServerTests { - @Override protected McpServerTransportProvider createMcpTransportProvider() { return new StdioServerTransportProvider(); } + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java index 2cd62889a..0462cbafe 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerCustomContextPathTests.java @@ -1,6 +1,7 @@ /* * Copyright 2024 - 2024 the original author or authors. */ + package io.modelcontextprotocol.server.transport; import com.fasterxml.jackson.databind.ObjectMapper; diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java deleted file mode 100644 index b04ecb3c4..000000000 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ /dev/null @@ -1,1390 +0,0 @@ -/* - * Copyright 2024 - 2025 the original author or authors. - */ -package io.modelcontextprotocol.server.transport; - -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; -import java.util.stream.Collectors; - -import com.fasterxml.jackson.databind.ObjectMapper; - -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.server.McpServer; -import io.modelcontextprotocol.server.McpServerFeatures; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; -import io.modelcontextprotocol.spec.McpSchema.ElicitResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import net.javacrumbs.jsonunit.core.Option; - -import org.apache.catalina.LifecycleException; -import org.apache.catalina.LifecycleState; -import org.apache.catalina.startup.Tomcat; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -import org.springframework.web.client.RestClient; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.InstanceOfAssertFactories.type; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; - -class HttpServletSseServerTransportProviderIntegrationTests { - - private static final int PORT = TomcatTestUtil.findAvailablePort(); - - private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; - - private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; - - private HttpServletSseServerTransportProvider mcpServerTransportProvider; - - McpClient.SyncSpec clientBuilder; - - private Tomcat tomcat; - - @BeforeEach - public void before() { - // Create and configure the transport provider - mcpServerTransportProvider = HttpServletSseServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) - .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build(); - - tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); - try { - tomcat.start(); - assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); - } - catch (Exception e) { - throw new RuntimeException("Failed to start Tomcat", e); - } - - this.clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) - .sseEndpoint(CUSTOM_SSE_ENDPOINT) - .build()); - } - - @AfterEach - public void after() { - if (mcpServerTransportProvider != null) { - mcpServerTransportProvider.closeGracefully().block(); - } - if (tomcat != null) { - try { - tomcat.stop(); - tomcat.destroy(); - } - catch (LifecycleException e) { - throw new RuntimeException("Failed to stop Tomcat", e); - } - } - } - - // --------------------------------------- - // Sampling Tests - // --------------------------------------- - @Test - // @Disabled - void testCreateMessageWithoutSamplingCapabilities() { - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - - return Mono.just(mock(CallToolResult.class)); - }) - .build(); - - var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - - try ( - // Create client without sampling capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) - .build()) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - } - } - server.close(); - } - - @Test - void testCreateMessageSuccess() { - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var createMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - mcpServer.close(); - } - - @Test - void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { - - // Client - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(3)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { - - // Client - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("Timeout"); - - mcpClient.close(); - mcpServer.close(); - } - - // --------------------------------------- - // Elicitation Tests - // --------------------------------------- - @Test - // @Disabled - void testCreateElicitationWithoutElicitationCapabilities() { - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - exchange.createElicitation(mock(ElicitRequest.class)).block(); - - return Mono.just(mock(CallToolResult.class)); - }) - .build(); - - var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - - try ( - // Create client without elicitation capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with elicitation capabilities"); - } - } - server.closeGracefully().block(); - } - - @Test - void testCreateElicitationSuccess() { - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - mcpServer.closeGracefully().block(); - } - - @Test - void testCreateElicitationWithRequestTimeoutSuccess() { - - // Client - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(3)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - @Test - void testCreateElicitationWithRequestTimeoutFail() { - - // Client - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("Timeout"); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - @Test - void testRootsSuccess() { - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); - - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); - - mcpServer.close(); - } - } - - @Test - void testRootsWithoutCapability() { - - McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - exchange.listRoots(); // try to list roots - - return mock(CallToolResult.class); - }) - .build(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { - }).tools(tool).build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); - } - } - - mcpServer.close(); - } - - @Test - void testRootsNotificationWithEmptyRootsList() { - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(List.of()) // Empty roots list - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - } - - mcpServer.close(); - } - - @Test - void testRootsWithMultipleHandlers() { - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef1 = new AtomicReference<>(); - AtomicReference> rootsRef2 = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - @Test - void testRootsServerCloseWithActiveSubscription() { - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testToolCallSuccess() { - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - assertThat(McpTestServletFilter.getThreadLocalValue()).as("blocking code exectuion should be offloaded") - .isNull(); - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }) - .build(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - - mcpServer.close(); - } - - @Test - void testToolCallImmediateExecution() { - McpServerFeatures.SyncToolSpecification tool1 = new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var threadLocalValue = McpTestServletFilter.getThreadLocalValue(); - return CallToolResult.builder() - .addTextContent(threadLocalValue != null ? threadLocalValue : "") - .build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .immediateExecution(true) - .build(); - - try (var mcpClient = clientBuilder.build()) { - mcpClient.initialize(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response.content()).first() - .asInstanceOf(type(McpSchema.TextContent.class)) - .extracting(McpSchema.TextContent::text) - .isEqualTo(McpTestServletFilter.THREAD_LOCAL_VALUE); - } - - mcpServer.close(); - } - - @Test - void testToolListChangeHandlingSuccess() { - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }) - .build(); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - rootsRef.set(toolsUpdate); - }).build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); - - // Remove a tool - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema)) - .callHandler((exchange, request) -> callResponse) - .build(); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); - } - - mcpServer.close(); - } - - @Test - void testInitialize() { - var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Logging Tests - // --------------------------------------- - @Test - void testLoggingNotification() { - // Create a list to store received logging notifications - List receivedNotifications = new CopyOnWriteArrayList<>(); - - // Create server with a tool that sends logging notifications - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("logging-test", "Test logging notifications", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - // Create and send notifications with different levels - - // This should be filtered out (DEBUG < NOTICE) - exchange - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.DEBUG) - .logger("test-logger") - .data("Debug message") - .build()) - .block(); - - // This should be sent (NOTICE >= NOTICE) - exchange - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.NOTICE) - .logger("test-logger") - .data("Notice message") - .build()) - .block(); - - // This should be sent (ERROR > NOTICE) - exchange - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Error message") - .build()) - .block(); - - // This should be filtered out (INFO < NOTICE) - exchange - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Another info message") - .build()) - .block(); - - // This should be sent (ERROR >= NOTICE) - exchange - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Another error message") - .build()) - .block(); - - return Mono.just(new CallToolResult("Logging test completed", false)); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().tools(true).build()) - .tools(tool) - .build(); - try ( - // Create client with logging notification handler - var mcpClient = clientBuilder.loggingConsumer(notification -> { - receivedNotifications.add(notification); - }).build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Set minimum logging level to NOTICE - mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); - - // Call the tool that sends logging notifications - CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); - - // Wait for notifications to be processed - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - - System.out.println("Received notifications: " + receivedNotifications); - - // Should have received 3 notifications (1 NOTICE and 2 ERROR) - assertThat(receivedNotifications).hasSize(3); - - Map notificationMap = receivedNotifications.stream() - .collect(Collectors.toMap(n -> n.data(), n -> n)); - - // First notification should be NOTICE level - assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); - assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); - - // Second notification should be ERROR level - assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); - - // Third notification should be ERROR level - assertThat(notificationMap.get("Another error message").level()) - .isEqualTo(McpSchema.LoggingLevel.ERROR); - assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); - assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); - }); - } - mcpServer.close(); - } - - // --------------------------------------- - // Progress Tests - // --------------------------------------- - @Test - void testProgressNotification() { - // Create a list to store received progress notifications - List receivedNotifications = new CopyOnWriteArrayList<>(); - - // Create server with a tool that sends progress notifications - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - McpSchema.Tool.builder() - .name("progress-test") - .description("Test progress notifications") - .inputSchema(emptyJsonSchema) - .build(), - null, (exchange, request) -> { - - var progressToken = request.progressToken(); - - exchange - .progressNotification( - new McpSchema.ProgressNotification(progressToken, 0.1, 1.0, "Test progress 1/10")) - .block(); - - exchange - .progressNotification( - new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Test progress 5/10")) - .block(); - - exchange - .progressNotification( - new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Test progress 10/10")) - .block(); - - return Mono.just(new CallToolResult("Progress test completed", false)); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().logging().tools(true).build()) - .tools(tool) - .build(); - - // Create client with progress notification handler - try (var mcpClient = clientBuilder.progressConsumer(receivedNotifications::add).build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call the tool that sends progress notifications - CallToolResult result = mcpClient.callTool( - new McpSchema.CallToolRequest("progress-test", Map.of(), Map.of("progressToken", "test-token"))); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); - - // Wait for notifications to be processed - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - // Should have received 3 notifications - assertThat(receivedNotifications).hasSize(3); - - // Check the progress notifications - assertThat(receivedNotifications.stream().map(McpSchema.ProgressNotification::progressToken)) - .containsExactlyInAnyOrder("test-token", "test-token", "test-token"); - assertThat(receivedNotifications.stream().map(McpSchema.ProgressNotification::progress)) - .containsExactlyInAnyOrder(0.1, 0.5, 1.0); - }); - } - finally { - mcpServer.close(); - } - } - - // --------------------------------------- - // Ping Tests - // --------------------------------------- - @Test - void testPingSuccess() { - // Create server with a tool that uses ping functionality - AtomicReference executionOrder = new AtomicReference<>(""); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("ping-async-test", "Test ping async behavior", emptyJsonSchema), - (exchange, request) -> { - - executionOrder.set(executionOrder.get() + "1"); - - // Test async ping behavior - return exchange.ping().doOnNext(result -> { - - assertThat(result).isNotNull(); - // Ping should return an empty object or map - assertThat(result).isInstanceOf(Map.class); - - executionOrder.set(executionOrder.get() + "2"); - assertThat(result).isNotNull(); - }).then(Mono.fromCallable(() -> { - executionOrder.set(executionOrder.get() + "3"); - return new CallToolResult("Async ping test completed", false); - })); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call the tool that tests ping async behavior - CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); - - // Verify execution order - assertThat(executionOrder.get()).isEqualTo("123"); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Tool Structured Output Schema Tests - // --------------------------------------- - @Test - void testStructuredOutputValidationSuccess() { - // Create a tool with output schema - Map outputSchema = Map.of( - "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", - Map.of("type", "string"), "timestamp", Map.of("type", "string")), - "required", List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - String expression = (String) request.getOrDefault("expression", "2 + 3"); - double result = evaluateExpression(expression); - return CallToolResult.builder() - .structuredContent( - Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) - .build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Verify tool is listed with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call tool with valid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - assertThatJson(((McpSchema.TextContent) response.content().get(0)).text()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); - - // Verify structured content (may be null in sync server but validation still - // works) - if (response.structuredContent() != null) { - assertThat(response.structuredContent()).containsEntry("result", 5.0) - .containsEntry("operation", "2 + 3") - .containsEntry("timestamp", "2024-01-01T10:00:00Z"); - } - } - - mcpServer.close(); - } - - @Test - void testStructuredOutputValidationFailure() { - - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", - List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - // Return invalid structured output. Result should be number, missing - // operation - return CallToolResult.builder() - .addTextContent("Invalid calculation") - .structuredContent(Map.of("result", "not-a-number", "extra", "field")) - .build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool with invalid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).contains("Validation failed"); - } - - mcpServer.close(); - } - - @Test - void testStructuredOutputMissingStructuredContent() { - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number")), "required", List.of("result")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - // Return result without structured content but tool has output schema - return CallToolResult.builder().addTextContent("Calculation completed").build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool that should return structured content but doesn't - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).isEqualTo( - "Response missing structured content which is expected when calling tool with non-empty outputSchema"); - } - - mcpServer.close(); - } - - @Test - void testStructuredOutputRuntimeToolAddition() { - // Start server without tools - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Initially no tools - assertThat(mcpClient.listTools().tools()).isEmpty(); - - // Add tool with output schema at runtime - Map outputSchema = Map.of("type", "object", "properties", - Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", - List.of("message", "count")); - - Tool dynamicTool = Tool.builder() - .name("dynamic-tool") - .description("Dynamically added tool") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification toolSpec = new McpServerFeatures.SyncToolSpecification(dynamicTool, - (exchange, request) -> { - int count = (Integer) request.getOrDefault("count", 1); - return CallToolResult.builder() - .addTextContent("Dynamic tool executed " + count + " times") - .structuredContent(Map.of("message", "Dynamic execution", "count", count)) - .build(); - }); - - // Add tool to server - mcpServer.addTool(toolSpec); - - // Wait for tool list change notification - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(mcpClient.listTools().tools()).hasSize(1); - }); - - // Verify tool was added with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call dynamically added tool - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - assertThat(response.structuredContent()).containsEntry("message", "Dynamic execution") - .containsEntry("count", 3); - } - - mcpServer.close(); - } - - private double evaluateExpression(String expression) { - // Simple expression evaluator for testing - return switch (expression) { - case "2 + 3" -> 5.0; - case "10 * 2" -> 20.0; - case "7 + 8" -> 15.0; - case "5 + 3" -> 8.0; - default -> 0.0; - }; - } - -} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java index 5a3928e02..2cf95dc94 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/TomcatTestUtil.java @@ -1,6 +1,7 @@ /* * Copyright 2025 - 2025 the original author or authors. */ + package io.modelcontextprotocol.server.transport; import java.io.IOException; diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java b/mcp/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java index ba4e851f9..a0bd568ef 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/ArgumentException.java @@ -1,3 +1,7 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + package io.modelcontextprotocol.spec; public class ArgumentException { diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidatorTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidatorTests.java index 9da31b38b..30158543d 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidatorTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/DefaultJsonSchemaValidatorTests.java @@ -1,6 +1,7 @@ /* * Copyright 2024-2024 the original author or authors. */ + package io.modelcontextprotocol.spec; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -26,7 +27,6 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.DefaultJsonSchemaValidator; import io.modelcontextprotocol.spec.JsonSchemaValidator.ValidationResponse; /** diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java b/mcp/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java new file mode 100644 index 000000000..d03a6926d --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/JSONRPCRequestMcpValidationTest.java @@ -0,0 +1,87 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests for MCP-specific validation of JSONRPCRequest ID requirements. + * + * @author Christian Tzolov + */ +public class JSONRPCRequestMcpValidationTest { + + @Test + public void testValidStringId() { + assertDoesNotThrow(() -> { + var request = new McpSchema.JSONRPCRequest("2.0", "test/method", "string-id", null); + assertEquals("string-id", request.id()); + }); + } + + @Test + public void testValidIntegerId() { + assertDoesNotThrow(() -> { + var request = new McpSchema.JSONRPCRequest("2.0", "test/method", 123, null); + assertEquals(123, request.id()); + }); + } + + @Test + public void testValidLongId() { + assertDoesNotThrow(() -> { + var request = new McpSchema.JSONRPCRequest("2.0", "test/method", 123L, null); + assertEquals(123L, request.id()); + }); + } + + @Test + public void testNullIdThrowsException() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + new McpSchema.JSONRPCRequest("2.0", "test/method", null, null); + }); + + assertTrue(exception.getMessage().contains("MCP requests MUST include an ID")); + assertTrue(exception.getMessage().contains("null IDs are not allowed")); + } + + @Test + public void testDoubleIdTypeThrowsException() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + new McpSchema.JSONRPCRequest("2.0", "test/method", 123.45, null); + }); + + assertTrue(exception.getMessage().contains("MCP requests MUST have an ID that is either a string or integer")); + } + + @Test + public void testBooleanIdThrowsException() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + new McpSchema.JSONRPCRequest("2.0", "test/method", true, null); + }); + + assertTrue(exception.getMessage().contains("MCP requests MUST have an ID that is either a string or integer")); + } + + @Test + public void testArrayIdThrowsException() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + new McpSchema.JSONRPCRequest("2.0", "test/method", new String[] { "array" }, null); + }); + + assertTrue(exception.getMessage().contains("MCP requests MUST have an ID that is either a string or integer")); + } + + @Test + public void testObjectIdThrowsException() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> { + new McpSchema.JSONRPCRequest("2.0", "test/method", new Object(), null); + }); + + assertTrue(exception.getMessage().contains("MCP requests MUST have an ID that is either a string or integer")); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index fbbb4307e..a5b2137fd 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -1,6 +1,7 @@ /* * Copyright 2025 - 2025 the original author or authors. */ + package io.modelcontextprotocol.spec; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; @@ -319,8 +320,8 @@ void testInitializeRequest() throws Exception { McpSchema.Implementation clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); Map meta = Map.of("metaKey", "metaValue"); - McpSchema.InitializeRequest request = new McpSchema.InitializeRequest("2024-11-05", capabilities, clientInfo, - meta); + McpSchema.InitializeRequest request = new McpSchema.InitializeRequest(ProtocolVersions.MCP_2024_11_05, + capabilities, clientInfo, meta); String value = mapper.writeValueAsString(request); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) @@ -342,8 +343,8 @@ void testInitializeResult() throws Exception { McpSchema.Implementation serverInfo = new McpSchema.Implementation("test-server", "1.0.0"); - McpSchema.InitializeResult result = new McpSchema.InitializeResult("2024-11-05", capabilities, serverInfo, - "Server initialized successfully"); + McpSchema.InitializeResult result = new McpSchema.InitializeResult(ProtocolVersions.MCP_2024_11_05, + capabilities, serverInfo, "Server initialized successfully"); String value = mapper.writeValueAsString(result); assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java b/mcp/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java new file mode 100644 index 000000000..4de9363c2 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/util/KeepAliveSchedulerTests.java @@ -0,0 +1,303 @@ +/* + * Copyright 2025-2025 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import com.fasterxml.jackson.core.type.TypeReference; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSession; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.scheduler.VirtualTimeScheduler; + +/** + * Unit tests for {@link KeepAliveScheduler}. + * + * @author Christian Tzolov + */ +class KeepAliveSchedulerTests { + + private MockMcpSession mockSession1; + + private MockMcpSession mockSession2; + + private Supplier> mockSessionsSupplier; + + private VirtualTimeScheduler virtualTimeScheduler; + + @BeforeEach + void setUp() { + virtualTimeScheduler = VirtualTimeScheduler.create(); + mockSession1 = new MockMcpSession(); + mockSession2 = new MockMcpSession(); + mockSessionsSupplier = () -> Flux.just(mockSession1); + } + + @AfterEach + void tearDown() { + if (virtualTimeScheduler != null) { + virtualTimeScheduler.dispose(); + } + } + + @Test + void testBuilderWithNullSessionsSupplier() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("McpSessions supplier must not be null"); + } + + @Test + void testBuilderWithNullScheduler() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(mockSessionsSupplier).scheduler(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Scheduler must not be null"); + } + + @Test + void testBuilderWithNullInitialDelay() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(mockSessionsSupplier).initialDelay(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Initial delay must not be null"); + } + + @Test + void testBuilderWithNullInterval() { + assertThatThrownBy(() -> KeepAliveScheduler.builder(mockSessionsSupplier).interval(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Interval must not be null"); + } + + @Test + void testBuilderDefaults() { + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier).build(); + + assertThat(scheduler).isNotNull(); + assertThat(scheduler.isRunning()).isFalse(); + } + + @Test + void testStartWithMultipleSessions() { + mockSessionsSupplier = () -> Flux.just(mockSession1, mockSession2); + + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .initialDelay(Duration.ofSeconds(1)) + .interval(Duration.ofSeconds(2)) + .build(); + + assertThat(scheduler.isRunning()).isFalse(); + + // Start the scheduler + Disposable disposable = scheduler.start(); + + assertThat(scheduler.isRunning()).isTrue(); + assertThat(disposable).isNotNull(); + assertThat(disposable.isDisposed()).isFalse(); + + // Advance time to trigger the first ping + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(1)); + + // Verify both sessions received ping + assertThat(mockSession1.getPingCount()).isEqualTo(1); + assertThat(mockSession2.getPingCount()).isEqualTo(1); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(2)); // Second ping + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(2)); // Third ping + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(2)); // Fourth ping + + // Verify second ping was sent + assertThat(mockSession1.getPingCount()).isEqualTo(4); + assertThat(mockSession2.getPingCount()).isEqualTo(4); + + // Clean up + scheduler.stop(); + + assertThat(scheduler.isRunning()).isFalse(); + assertThat(disposable).isNotNull(); + assertThat(disposable.isDisposed()).isTrue(); + } + + @Test + void testStartWithEmptySessionsList() { + mockSessionsSupplier = () -> Flux.empty(); + + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .initialDelay(Duration.ofSeconds(1)) + .interval(Duration.ofSeconds(2)) + .build(); + + // Start the scheduler + scheduler.start(); + + // Advance time to trigger ping attempts + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(1)); + + // Verify no sessions were called (since list was empty) + assertThat(mockSession1.getPingCount()).isEqualTo(0); + assertThat(mockSession2.getPingCount()).isEqualTo(0); + + // Clean up + scheduler.stop(); + } + + @Test + void testStartWhenAlreadyRunning() { + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .build(); + + // Start the scheduler + scheduler.start(); + + // Try to start again - should throw exception + assertThatThrownBy(scheduler::start).isInstanceOf(IllegalStateException.class) + .hasMessage("KeepAlive scheduler is already running. Stop it first."); + + // Clean up + scheduler.stop(); + } + + @Test + void testStopWhenNotRunning() { + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .build(); + + // Should not throw exception when stopping a non-running scheduler + assertDoesNotThrow(scheduler::stop); + assertThat(scheduler.isRunning()).isFalse(); + } + + @Test + void testShutdown() { + // Setup with a separate virtual time scheduler (which is disposable) + VirtualTimeScheduler separateScheduler = VirtualTimeScheduler.create(); + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(separateScheduler) + .build(); + + // Start the scheduler + scheduler.start(); + assertThat(scheduler.isRunning()).isTrue(); + + // Shutdown should stop the scheduler and dispose the scheduler + scheduler.shutdown(); + assertThat(scheduler.isRunning()).isFalse(); + assertThat(separateScheduler.isDisposed()).isTrue(); + } + + @Test + void testPingFailureHandling() { + // Setup session that fails ping + mockSession1.setShouldFailPing(true); + + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .initialDelay(Duration.ofSeconds(1)) + .interval(Duration.ofSeconds(2)) + .build(); + + // Start the scheduler + scheduler.start(); + + // Advance time to trigger the ping + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(1)); + + // Verify ping was attempted (error should be handled gracefully) + assertThat(mockSession1.getPingCount()).isEqualTo(1); + + // Scheduler should still be running despite the error + assertThat(scheduler.isRunning()).isTrue(); + + // Clean up + scheduler.stop(); + } + + @Test + void testDisposableReturnedFromStart() { + KeepAliveScheduler scheduler = KeepAliveScheduler.builder(mockSessionsSupplier) + .scheduler(virtualTimeScheduler) + .build(); + + // Start and get disposable + Disposable disposable = scheduler.start(); + + assertThat(disposable).isNotNull(); + assertThat(disposable.isDisposed()).isFalse(); + assertThat(scheduler.isRunning()).isTrue(); + + // Dispose directly through the returned disposable + disposable.dispose(); + + assertThat(disposable.isDisposed()).isTrue(); + assertThat(scheduler.isRunning()).isFalse(); + } + + /** + * Simple mock implementation of McpSession for testing purposes. + */ + private static class MockMcpSession implements McpSession { + + private final AtomicInteger pingCount = new AtomicInteger(0); + + private boolean shouldFailPing = false; + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + if (McpSchema.METHOD_PING.equals(method)) { + pingCount.incrementAndGet(); + if (shouldFailPing) { + return Mono.error(new RuntimeException("Connection failed")); + } + return Mono.just((T) new Object()); + } + return Mono.empty(); + } + + @Override + public Mono sendNotification(String method, Object params) { + return Mono.empty(); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public void close() { + // No-op for mock + } + + public int getPingCount() { + return pingCount.get(); + } + + public void setShouldFailPing(boolean shouldFailPing) { + this.shouldFailPing = shouldFailPing; + } + + @Override + public String toString() { + return "MockMcpSession"; + } + + } + +} diff --git a/pom.xml b/pom.xml index b7a66aeec..c0b1f7a44 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ io.modelcontextprotocol.sdk mcp-parent - 0.11.0-SNAPSHOT + 0.12.0-SNAPSHOT pom https://github.com/modelcontextprotocol/java-sdk