Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

namespace ModelContextProtocol.AspNetCore;

internal sealed class StreamableHttpHandler(
internal sealed partial class StreamableHttpHandler(
IOptions<McpServerOptions> mcpServerOptionsSnapshot,
IOptionsFactory<McpServerOptions> mcpServerOptionsFactory,
IOptions<HttpServerTransportOptions> httpServerTransportOptions,
Expand All @@ -28,6 +28,20 @@ internal sealed class StreamableHttpHandler(
{
private const string McpSessionIdHeaderName = "Mcp-Session-Id";
private static readonly JsonTypeInfo<JsonRpcError> s_errorTypeInfo = GetRequiredJsonTypeInfo<JsonRpcError>();

private readonly ILogger _logger = loggerFactory.CreateLogger<StreamableHttpHandler>();

// Headers that are safe and relevant to log for MCP over HTTP
private static readonly HashSet<string> SafeHeadersToLog = new(StringComparer.OrdinalIgnoreCase)
{
"Accept",
"Content-Type",
"Content-Length",
"User-Agent",
McpSessionIdHeaderName,
"X-Request-ID",
"X-Correlation-ID"
};

public ConcurrentDictionary<string, HttpMcpSession<StreamableHttpServerTransport>> Sessions { get; } = new(StringComparer.Ordinal);

Expand All @@ -37,6 +51,9 @@ internal sealed class StreamableHttpHandler(

public async Task HandlePostRequestAsync(HttpContext context)
{
// Log request headers for debugging
LogHttpRequestHeadersIfEnabled(context);

// The Streamable HTTP spec mandates the client MUST accept both application/json and text/event-stream.
// ASP.NET Core Minimal APIs mostly try to stay out of the business of response content negotiation,
// so we have to do this manually. The spec doesn't mandate that servers MUST reject these requests,
Expand Down Expand Up @@ -83,6 +100,9 @@ await WriteJsonRpcErrorAsync(context,

public async Task HandleGetRequestAsync(HttpContext context)
{
// Log request headers for debugging
LogHttpRequestHeadersIfEnabled(context);

if (!context.Request.GetTypedHeaders().Accept.Any(MatchesTextEventStreamMediaType))
{
await WriteJsonRpcErrorAsync(context,
Expand Down Expand Up @@ -118,6 +138,9 @@ await WriteJsonRpcErrorAsync(context,

public async Task HandleDeleteRequestAsync(HttpContext context)
{
// Log request headers for debugging
LogHttpRequestHeadersIfEnabled(context);

var sessionId = context.Request.Headers[McpSessionIdHeaderName].ToString();
if (Sessions.TryRemove(sessionId, out var session))
{
Expand Down Expand Up @@ -336,6 +359,27 @@ private static bool MatchesApplicationJsonMediaType(MediaTypeHeaderValue acceptH
private static bool MatchesTextEventStreamMediaType(MediaTypeHeaderValue acceptHeaderValue)
=> acceptHeaderValue.MatchesMediaType("text/event-stream");

private void LogHttpRequestHeadersIfEnabled(HttpContext context)
{
if (_logger.IsEnabled(LogLevel.Trace))
{
var safeHeaders = new Dictionary<string, string>();

foreach (var header in context.Request.Headers)
{
if (SafeHeadersToLog.Contains(header.Key))
{
safeHeaders[header.Key] = header.Value.ToString();
}
}

LogHttpRequestHeaders(context.Request.Method, context.Request.Path, safeHeaders);
}
}

[LoggerMessage(Level = LogLevel.Trace, Message = "HTTP {Method} {Path} - Headers: {Headers}")]
private partial void LogHttpRequestHeaders(string method, string path, Dictionary<string, string> headers);

private sealed class HttpDuplexPipe(HttpContext context) : IDuplexPipe
{
public PipeReader Input => context.Request.BodyReader;
Expand Down
16 changes: 14 additions & 2 deletions src/ModelContextProtocol.Core/McpSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ public McpSession(
_requestHandlers = requestHandlers;
_notificationHandlers = notificationHandlers;
_logger = logger ?? NullLogger.Instance;
LogSessionCreated(EndpointName, _sessionId, _transportKind);
}

/// <summary>
Expand All @@ -108,6 +109,7 @@ public McpSession(
/// </summary>
public async Task ProcessMessagesAsync(CancellationToken cancellationToken)
{
LogSessionConnected(EndpointName, _sessionId, _transportKind);
try
{
await foreach (var message in _transport.MessageReader.ReadAllAsync(cancellationToken).ConfigureAwait(false))
Expand Down Expand Up @@ -609,9 +611,9 @@ private static void AddExceptionTags(ref TagList tags, Activity? activity, Excep
e = ae.InnerException;
}

int? intErrorCode =
int? intErrorCode =
(int?)((e as McpException)?.ErrorCode) is int errorCode ? errorCode :
e is JsonException ? (int)McpErrorCode.ParseError :
e is JsonException ? (int)McpErrorCode.ParseError :
null;

string? errorType = intErrorCode?.ToString() ?? e.GetType().FullName;
Expand Down Expand Up @@ -692,6 +694,7 @@ public void Dispose()
}

_pendingRequests.Clear();
LogSessionDisposed(EndpointName, _sessionId, _transportKind);
}

#if !NET
Expand Down Expand Up @@ -774,4 +777,13 @@ private static TimeSpan GetElapsed(long startingTimestamp) =>

[LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} sending message. Message: '{Message}'.")]
private partial void LogSendingMessageSensitive(string endpointName, string message);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} session {SessionId} created with transport {TransportKind}")]
private partial void LogSessionCreated(string endpointName, string sessionId, string transportKind);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} session {SessionId} connected and processing messages with transport {TransportKind}")]
private partial void LogSessionConnected(string endpointName, string sessionId, string transportKind);

[LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} session {SessionId} disposed with transport {TransportKind}")]
private partial void LogSessionDisposed(string endpointName, string sessionId, string transportKind);
}