Skip to content

feat: Add GetHeader() Support for SSE Transport Sessions (ClientSession) #444

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"reflect"
"sort"
"testing"
Expand Down Expand Up @@ -1522,6 +1523,11 @@ func (f fakeSession) Initialized() bool {
return f.initialized
}

func (f fakeSession) GetHeader() http.Header {
// Test session doesn't have HTTP headers, return empty header
return make(http.Header)
}

var _ ClientSession = fakeSession{}

func TestMCPServer_WithHooks(t *testing.T) {
Expand Down
3 changes: 3 additions & 0 deletions server/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"context"
"fmt"
"net/http"

"github.com/mark3labs/mcp-go/mcp"
)
Expand All @@ -17,6 +18,8 @@ type ClientSession interface {
NotificationChannel() chan<- mcp.JSONRPCNotification
// SessionID is a unique identifier used to track user session.
SessionID() string
// GetHeader returns the HTTP headers from the initial request
GetHeader() http.Header
}

// SessionWithLogging is an extension of ClientSession that can receive log message notifications and set log level
Expand Down
21 changes: 21 additions & 0 deletions server/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"net/http"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -42,6 +43,11 @@ func (f sessionTestClient) Initialized() bool {
return f.initialized
}

func (f sessionTestClient) GetHeader() http.Header {
// Test session doesn't have HTTP headers, return empty header
return make(http.Header)
}

// sessionTestClientWithTools implements the SessionWithTools interface for testing
type sessionTestClientWithTools struct {
sessionID string
Expand Down Expand Up @@ -100,6 +106,11 @@ func (f *sessionTestClientWithTools) SetSessionTools(tools map[string]ServerTool
f.sessionTools = toolsCopy
}

func (f *sessionTestClientWithTools) GetHeader() http.Header {
// Test session doesn't have HTTP headers, return empty header
return make(http.Header)
}

// sessionTestClientWithClientInfo implements the SessionWithClientInfo interface for testing
type sessionTestClientWithClientInfo struct {
sessionID string
Expand Down Expand Up @@ -137,6 +148,11 @@ func (f *sessionTestClientWithClientInfo) SetClientInfo(clientInfo mcp.Implement
f.clientInfo.Store(clientInfo)
}

func (f *sessionTestClientWithClientInfo) GetHeader() http.Header {
// Test session doesn't have HTTP headers, return empty header
return make(http.Header)
}

// sessionTestClientWithTools implements the SessionWithLogging interface for testing
type sessionTestClientWithLogging struct {
sessionID string
Expand Down Expand Up @@ -172,6 +188,11 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel {
return level.(mcp.LoggingLevel)
}

func (f *sessionTestClientWithLogging) GetHeader() http.Header {
// Test session doesn't have HTTP headers, return empty header
return make(http.Header)
}

// Verify that all implementations satisfy their respective interfaces
var (
_ ClientSession = (*sessionTestClient)(nil)
Expand Down
6 changes: 6 additions & 0 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type sseSession struct {
loggingLevel atomic.Value
tools sync.Map // stores session-specific tools
clientInfo atomic.Value // stores session-specific client info
headers http.Header // stores HTTP headers from the initial request
}

// SSEContextFunc is a function that takes an existing context and the current
Expand Down Expand Up @@ -108,6 +109,10 @@ func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) {
s.clientInfo.Store(clientInfo)
}

func (s *sseSession) GetHeader() http.Header {
return s.headers
}

var (
_ ClientSession = (*sseSession)(nil)
_ SessionWithTools = (*sseSession)(nil)
Expand Down Expand Up @@ -353,6 +358,7 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
eventQueue: make(chan string, 100), // Buffer for events
sessionID: sessionID,
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
headers: r.Header.Clone(), // Store HTTP headers from request
}

s.sessions.Store(sessionID, session)
Expand Down
6 changes: 6 additions & 0 deletions server/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"log"
"net/http"
"os"
"os/signal"
"sync/atomic"
Expand Down Expand Up @@ -100,6 +101,11 @@ func (s *stdioSession) GetLogLevel() mcp.LoggingLevel {
return level.(mcp.LoggingLevel)
}

func (s *stdioSession) GetHeader() http.Header {
// stdio transport doesn't have HTTP headers, return empty header
return make(http.Header)
}

var (
_ ClientSession = (*stdioSession)(nil)
_ SessionWithLogging = (*stdioSession)(nil)
Expand Down
17 changes: 16 additions & 1 deletion server/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request)
sessionID = uuid.New().String()
}

session := newStreamableHttpSession(sessionID, s.sessionTools)
session := newStreamableHttpSessionWithHeaders(sessionID, s.sessionTools, r.Header)
if err := s.server.RegisterSession(r.Context(), session); err != nil {
http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest)
return
Expand Down Expand Up @@ -549,13 +549,24 @@ type streamableHttpSession struct {
notificationChannel chan mcp.JSONRPCNotification // server -> client notifications
tools *sessionToolsStore
upgradeToSSE atomic.Bool
headers http.Header // stores HTTP headers from the initial request
}

func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession {
return &streamableHttpSession{
sessionID: sessionID,
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
tools: toolStore,
headers: make(http.Header), // Initialize empty headers for ephemeral sessions
}
}

func newStreamableHttpSessionWithHeaders(sessionID string, toolStore *sessionToolsStore, headers http.Header) *streamableHttpSession {
return &streamableHttpSession{
sessionID: sessionID,
notificationChannel: make(chan mcp.JSONRPCNotification, 100),
tools: toolStore,
headers: headers.Clone(),
}
}

Expand Down Expand Up @@ -595,6 +606,10 @@ func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() {

var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil)

func (s *streamableHttpSession) GetHeader() http.Header {
return s.headers
}

// --- session id manager ---

type SessionIdManager interface {
Expand Down