Skip to content

WIP: AuthContext #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mcp-bom/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<parent>
<groupId>io.modelcontextprotocol.sdk</groupId>
<artifactId>mcp-parent</artifactId>
<version>0.11.0-SNAPSHOT</version>
<version>0.11.1-mcp-tool-level-authnz-SNAPSHOT</version>
</parent>

<artifactId>mcp-bom</artifactId>
Expand Down
6 changes: 3 additions & 3 deletions mcp-spring/mcp-spring-webflux/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<parent>
<groupId>io.modelcontextprotocol.sdk</groupId>
<artifactId>mcp-parent</artifactId>
<version>0.11.0-SNAPSHOT</version>
<version>0.11.1-mcp-tool-level-authnz-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
</parent>
<artifactId>mcp-spring-webflux</artifactId>
Expand All @@ -25,13 +25,13 @@
<dependency>
<groupId>io.modelcontextprotocol.sdk</groupId>
<artifactId>mcp</artifactId>
<version>0.11.0-SNAPSHOT</version>
<version>0.11.1-mcp-tool-level-authnz-SNAPSHOT</version>
</dependency>

<dependency>
<groupId>io.modelcontextprotocol.sdk</groupId>
<artifactId>mcp-test</artifactId>
<version>0.11.0-SNAPSHOT</version>
<version>0.11.1-mcp-tool-level-authnz-SNAPSHOT</version>
<scope>test</scope>
</dependency>

Expand Down
6 changes: 3 additions & 3 deletions mcp-spring/mcp-spring-webmvc/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<parent>
<groupId>io.modelcontextprotocol.sdk</groupId>
<artifactId>mcp-parent</artifactId>
<version>0.11.0-SNAPSHOT</version>
<version>0.11.1-mcp-tool-level-authnz-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
</parent>
<artifactId>mcp-spring-webmvc</artifactId>
Expand All @@ -25,13 +25,13 @@
<dependency>
<groupId>io.modelcontextprotocol.sdk</groupId>
<artifactId>mcp</artifactId>
<version>0.11.0-SNAPSHOT</version>
<version>0.11.1-mcp-tool-level-authnz-SNAPSHOT</version>
</dependency>

<dependency>
<groupId>io.modelcontextprotocol.sdk</groupId>
<artifactId>mcp-test</artifactId>
<version>0.11.0-SNAPSHOT</version>
<version>0.11.1-mcp-tool-level-authnz-SNAPSHOT</version>
<scope>test</scope>
</dependency>

Expand Down
4 changes: 2 additions & 2 deletions mcp-test/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<parent>
<groupId>io.modelcontextprotocol.sdk</groupId>
<artifactId>mcp-parent</artifactId>
<version>0.11.0-SNAPSHOT</version>
<version>0.11.1-mcp-tool-level-authnz-SNAPSHOT</version>
</parent>
<artifactId>mcp-test</artifactId>
<packaging>jar</packaging>
Expand All @@ -24,7 +24,7 @@
<dependency>
<groupId>io.modelcontextprotocol.sdk</groupId>
<artifactId>mcp</artifactId>
<version>0.11.0-SNAPSHOT</version>
<version>0.11.1-mcp-tool-level-authnz-SNAPSHOT</version>
</dependency>

<dependency>
Expand Down
2 changes: 1 addition & 1 deletion mcp/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<parent>
<groupId>io.modelcontextprotocol.sdk</groupId>
<artifactId>mcp-parent</artifactId>
<version>0.11.0-SNAPSHOT</version>
<version>0.11.1-mcp-tool-level-authnz-SNAPSHOT</version>
</parent>
<artifactId>mcp</artifactId>
<packaging>jar</packaging>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.server.auth.SecurityContext;
import io.modelcontextprotocol.spec.McpClientSession;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
Expand Down Expand Up @@ -183,9 +184,9 @@ public class McpAsyncServer {
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED,
asyncRootsListChangedNotificationHandler(rootsChangeConsumers));

mcpTransportProvider.setSessionFactory(
transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport,
this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers));
mcpTransportProvider.setSessionFactory(transport -> new McpServerSession(UUID.randomUUID().toString(),
requestTimeout, transport, this::asyncInitializeRequestHandler, Mono::empty, requestHandlers,
notificationHandlers, SecurityContext.EMPTY));
}

// ---------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package io.modelcontextprotocol.server;

import com.fasterxml.jackson.core.type.TypeReference;
import io.modelcontextprotocol.server.auth.SecurityContext;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.LoggingLevel;
Expand All @@ -28,6 +29,8 @@ public class McpAsyncServerExchange {

private final McpSchema.Implementation clientInfo;

private final SecurityContext securityContext;

private volatile LoggingLevel minLoggingLevel = LoggingLevel.INFO;

private static final TypeReference<McpSchema.CreateMessageResult> CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() {
Expand All @@ -47,10 +50,11 @@ public class McpAsyncServerExchange {
* @param clientInfo The client implementation information.
*/
public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities,
McpSchema.Implementation clientInfo) {
McpSchema.Implementation clientInfo, SecurityContext securityContext) {
this.session = session;
this.clientCapabilities = clientCapabilities;
this.clientInfo = clientInfo;
this.securityContext = securityContext;
}

/**
Expand Down Expand Up @@ -159,6 +163,11 @@ public Mono<Void> loggingNotification(LoggingMessageNotification loggingMessageN
});
}

public Mono<SecurityContext> getSecurityContext() {
// defer()? Could securityContext change over time, e.g. token refreshes?
return Mono.just(securityContext == null ? SecurityContext.EMPTY : securityContext);
}

/**
* Set the minimum logging level for the client. Messages below this level will be
* filtered out.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

package io.modelcontextprotocol.server;

import io.modelcontextprotocol.server.auth.SecurityContext;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.LoggingLevel;
import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification;

/**
Expand Down Expand Up @@ -108,4 +108,8 @@ public void loggingNotification(LoggingMessageNotification loggingMessageNotific
this.exchange.loggingNotification(loggingMessageNotification).block();
}

public SecurityContext getSecurityContext() {
return this.exchange.getSecurityContext().block();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package io.modelcontextprotocol.server.auth;

import java.security.Principal;

public record SecurityContext(Principal principal, String authHeader) {
// absent SecurityContext marker
public static final SecurityContext EMPTY = new SecurityContext(null, "");

public boolean isEmpty() {
return this == EMPTY;
}

public boolean isPresent() {
return !isEmpty();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.server.auth.SecurityContext;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
Expand Down Expand Up @@ -191,6 +192,14 @@ public Mono<Void> notifyClients(String method, Object params) {
protected void doGet(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException {

SecurityContext securityContext = (SecurityContext) request.getAttribute("securityContext");
if (securityContext == null) {
// if null but auth is not required...
// TODO: and auth is required...
response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Unauthorized");
return;
}

String requestURI = request.getRequestURI();
if (!requestURI.endsWith(sseEndpoint)) {
response.sendError(HttpServletResponse.SC_NOT_FOUND);
Expand Down Expand Up @@ -220,6 +229,11 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)

// Create a new session using the session factory
McpServerSession session = sessionFactory.create(sessionTransport);

// set security context if available (otherwise it'll be SecurityContext.EMPTY)
if (securityContext != null) {
session.setSecurityContext(securityContext);
}
this.sessions.put(sessionId, session);

// Send initial endpoint event
Expand All @@ -246,6 +260,13 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
return;
}

SecurityContext securityContext = (SecurityContext) request.getAttribute("securityContext");
if (securityContext == null) {
// TODO: and auth is required...
response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Unauthorized");
return;
}

String requestURI = request.getRequestURI();
if (!requestURI.endsWith(messageEndpoint)) {
response.sendError(HttpServletResponse.SC_NOT_FOUND);
Expand Down Expand Up @@ -278,6 +299,19 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
return;
}

// set security context if available (otherwise it'll be SecurityContext.EMPTY)
if (securityContext != null) {
// Update the session's security context
if (session.getSecurityContext() == null) {
// TODO should we allow this if no previous security context?
}
else if (!session.getSecurityContext().principal().equals(securityContext.principal())) {
// TODO if the principal has changed, return unauthorized
// don't allow changing the principal for the session
}
session.setSecurityContext(securityContext);
}

try {
BufferedReader reader = request.getReader();
StringBuilder body = new StringBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import com.fasterxml.jackson.core.type.TypeReference;
import io.modelcontextprotocol.server.McpAsyncServerExchange;
import io.modelcontextprotocol.server.auth.SecurityContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
Expand Down Expand Up @@ -48,6 +49,8 @@ public class McpServerSession implements McpSession {

private final AtomicReference<McpSchema.Implementation> clientInfo = new AtomicReference<>();

private SecurityContext securityContext;

private static final int STATE_UNINITIALIZED = 0;

private static final int STATE_INITIALIZING = 1;
Expand All @@ -68,17 +71,20 @@ public class McpServerSession implements McpSession {
* received.
* @param requestHandlers map of request handlers to use
* @param notificationHandlers map of notification handlers to use
* @param securityContext the authentication context for this session
*/
public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport,
InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler,
Map<String, RequestHandler<?>> requestHandlers, Map<String, NotificationHandler> notificationHandlers) {
Map<String, RequestHandler<?>> requestHandlers, Map<String, NotificationHandler> notificationHandlers,
SecurityContext securityContext) {
this.id = id;
this.requestTimeout = requestTimeout;
this.transport = transport;
this.initRequestHandler = initHandler;
this.initNotificationHandler = initNotificationHandler;
this.requestHandlers = requestHandlers;
this.notificationHandlers = notificationHandlers;
this.securityContext = securityContext != null ? securityContext : SecurityContext.EMPTY;
}

/**
Expand Down Expand Up @@ -242,7 +248,8 @@ private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification noti
return Mono.defer(() -> {
if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) {
this.state.lazySet(STATE_INITIALIZED);
exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get()));
exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get(),
this.securityContext));
return this.initNotificationHandler.handle();
}

Expand All @@ -255,6 +262,14 @@ private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification noti
});
}

public SecurityContext getSecurityContext() {
return securityContext;
}

public void setSecurityContext(SecurityContext securityContext) {
this.securityContext = securityContext != null ? securityContext : SecurityContext.EMPTY;
}

record MethodNotFoundError(String method, String message, Object data) {
}

Expand Down Expand Up @@ -321,6 +336,7 @@ public interface NotificationHandler {
* @param <T> the type of the response that is expected as a result of handling the
* request.
*/
@FunctionalInterface
public interface RequestHandler<T> {

/**
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>io.modelcontextprotocol.sdk</groupId>
<artifactId>mcp-parent</artifactId>
<version>0.11.0-SNAPSHOT</version>
<version>0.11.1-mcp-tool-level-authnz-SNAPSHOT</version>

<packaging>pom</packaging>
<url>https://github.com/modelcontextprotocol/java-sdk</url>
Expand Down