diff --git a/server/server_test.go b/server/server_test.go index 1c81d18dd..40c2b8f2a 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "reflect" "sort" "testing" @@ -1522,6 +1523,11 @@ func (f fakeSession) Initialized() bool { return f.initialized } +func (f fakeSession) GetHeader() http.Header { + // Test session doesn't have HTTP headers, return empty header + return make(http.Header) +} + var _ ClientSession = fakeSession{} func TestMCPServer_WithHooks(t *testing.T) { diff --git a/server/session.go b/server/session.go index a79da22ca..1d05d6c04 100644 --- a/server/session.go +++ b/server/session.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + "net/http" "github.com/mark3labs/mcp-go/mcp" ) @@ -17,6 +18,8 @@ type ClientSession interface { NotificationChannel() chan<- mcp.JSONRPCNotification // SessionID is a unique identifier used to track user session. SessionID() string + // GetHeader returns the HTTP headers from the initial request + GetHeader() http.Header } // SessionWithLogging is an extension of ClientSession that can receive log message notifications and set log level diff --git a/server/session_test.go b/server/session_test.go index 3067f4e9c..c299bc10c 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "errors" + "net/http" "sync" "sync/atomic" "testing" @@ -42,6 +43,11 @@ func (f sessionTestClient) Initialized() bool { return f.initialized } +func (f sessionTestClient) GetHeader() http.Header { + // Test session doesn't have HTTP headers, return empty header + return make(http.Header) +} + // sessionTestClientWithTools implements the SessionWithTools interface for testing type sessionTestClientWithTools struct { sessionID string @@ -100,6 +106,11 @@ func (f *sessionTestClientWithTools) SetSessionTools(tools map[string]ServerTool f.sessionTools = toolsCopy } +func (f *sessionTestClientWithTools) GetHeader() http.Header { + // Test session doesn't have HTTP headers, return empty header + return make(http.Header) +} + // sessionTestClientWithClientInfo implements the SessionWithClientInfo interface for testing type sessionTestClientWithClientInfo struct { sessionID string @@ -137,6 +148,11 @@ func (f *sessionTestClientWithClientInfo) SetClientInfo(clientInfo mcp.Implement f.clientInfo.Store(clientInfo) } +func (f *sessionTestClientWithClientInfo) GetHeader() http.Header { + // Test session doesn't have HTTP headers, return empty header + return make(http.Header) +} + // sessionTestClientWithTools implements the SessionWithLogging interface for testing type sessionTestClientWithLogging struct { sessionID string @@ -172,6 +188,11 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +func (f *sessionTestClientWithLogging) GetHeader() http.Header { + // Test session doesn't have HTTP headers, return empty header + return make(http.Header) +} + // Verify that all implementations satisfy their respective interfaces var ( _ ClientSession = (*sessionTestClient)(nil) diff --git a/server/sse.go b/server/sse.go index 416995730..6497199ec 100644 --- a/server/sse.go +++ b/server/sse.go @@ -30,6 +30,7 @@ type sseSession struct { loggingLevel atomic.Value tools sync.Map // stores session-specific tools clientInfo atomic.Value // stores session-specific client info + headers http.Header // stores HTTP headers from the initial request } // SSEContextFunc is a function that takes an existing context and the current @@ -108,6 +109,10 @@ func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) { s.clientInfo.Store(clientInfo) } +func (s *sseSession) GetHeader() http.Header { + return s.headers +} + var ( _ ClientSession = (*sseSession)(nil) _ SessionWithTools = (*sseSession)(nil) @@ -353,6 +358,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { eventQueue: make(chan string, 100), // Buffer for events sessionID: sessionID, notificationChannel: make(chan mcp.JSONRPCNotification, 100), + headers: r.Header.Clone(), // Store HTTP headers from request } s.sessions.Store(sessionID, session) diff --git a/server/stdio.go b/server/stdio.go index 746a7d96f..e3fd2cc45 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "log" + "net/http" "os" "os/signal" "sync/atomic" @@ -100,6 +101,11 @@ func (s *stdioSession) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +func (s *stdioSession) GetHeader() http.Header { + // stdio transport doesn't have HTTP headers, return empty header + return make(http.Header) +} + var ( _ ClientSession = (*stdioSession)(nil) _ SessionWithLogging = (*stdioSession)(nil) diff --git a/server/streamable_http.go b/server/streamable_http.go index 1312c9753..35598a56f 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -365,7 +365,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) sessionID = uuid.New().String() } - session := newStreamableHttpSession(sessionID, s.sessionTools) + session := newStreamableHttpSessionWithHeaders(sessionID, s.sessionTools, r.Header) if err := s.server.RegisterSession(r.Context(), session); err != nil { http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest) return @@ -549,6 +549,7 @@ type streamableHttpSession struct { notificationChannel chan mcp.JSONRPCNotification // server -> client notifications tools *sessionToolsStore upgradeToSSE atomic.Bool + headers http.Header // stores HTTP headers from the initial request } func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession { @@ -556,6 +557,16 @@ func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *s sessionID: sessionID, notificationChannel: make(chan mcp.JSONRPCNotification, 100), tools: toolStore, + headers: make(http.Header), // Initialize empty headers for ephemeral sessions + } +} + +func newStreamableHttpSessionWithHeaders(sessionID string, toolStore *sessionToolsStore, headers http.Header) *streamableHttpSession { + return &streamableHttpSession{ + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + tools: toolStore, + headers: headers.Clone(), } } @@ -595,6 +606,10 @@ func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil) +func (s *streamableHttpSession) GetHeader() http.Header { + return s.headers +} + // --- session id manager --- type SessionIdManager interface {