Skip to content

Commit a8a2ec8

Browse files
committed
mcp: add StreamableHTTPOptions.SessionTimeout
Add a timeout option for the streamable HTTP handler that automatically cleans up idle sessions. Also, fix a bug in the streamable client, where we hang on a request even though the client can never get a response (because the HTTP request terminated without a response or Last-Event-Id). Fixes #499
1 parent 8fe64fc commit a8a2ec8

File tree

3 files changed

+187
-43
lines changed

3 files changed

+187
-43
lines changed

mcp/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,7 @@ func (s *Server) disconnect(cc *ServerSession) {
825825
type ServerSessionOptions struct {
826826
State *ServerSessionState
827827

828-
onClose func()
828+
onClose func() // used to clean up associated resources
829829
}
830830

831831
// Connect connects the MCP server over the given transport and starts handling

mcp/streamable.go

Lines changed: 121 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"fmt"
1313
"io"
1414
"log/slog"
15+
"maps"
1516
"math"
1617
"math/rand/v2"
1718
"net/http"
@@ -40,12 +41,46 @@ type StreamableHTTPHandler struct {
4041
getServer func(*http.Request) *Server
4142
opts StreamableHTTPOptions
4243

43-
onTransportDeletion func(sessionID string) // for testing only
44+
onTransportDeletion func(sessionID string) // for testing
4445

45-
mu sync.Mutex
46-
// TODO: we should store the ServerSession along with the transport, because
47-
// we need to cancel keepalive requests when closing the transport.
48-
transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header)
46+
mu sync.Mutex
47+
sessions map[string]*sessionInfo // keyed by session ID
48+
}
49+
50+
type sessionInfo struct {
51+
session *ServerSession
52+
transport *StreamableServerTransport
53+
54+
// If timeout is set, automatically close the session after an idle period.
55+
timeout time.Duration
56+
timerMu sync.Mutex
57+
timer *time.Timer
58+
}
59+
60+
// resetTimeout resets the inactivity timer.
61+
func (i *sessionInfo) resetTimeout() {
62+
if i.timeout <= 0 {
63+
return
64+
}
65+
66+
i.timerMu.Lock()
67+
defer i.timerMu.Unlock()
68+
69+
if i.timer == nil {
70+
return
71+
}
72+
// Reset the timer if we successfully stopped it.
73+
i.timer.Reset(i.timeout)
74+
}
75+
76+
// stopTimer stops the inactivity timer.
77+
func (i *sessionInfo) stopTimer() {
78+
i.timerMu.Lock()
79+
defer i.timerMu.Unlock()
80+
if i.timer != nil {
81+
i.timer.Stop()
82+
i.timer = nil
83+
}
4984
}
5085

5186
// StreamableHTTPOptions configures the StreamableHTTPHandler.
@@ -72,6 +107,14 @@ type StreamableHTTPOptions struct {
72107
// If nil, do not log.
73108
Logger *slog.Logger
74109

110+
// SessionTimeout configures a timeout for idle sessions.
111+
//
112+
// When sessions receive no new HTTP requests from the client for this
113+
// duration, they are automatically closed.
114+
//
115+
// If SessionTimeout is the zero value, idle sessions are never closed.
116+
SessionTimeout time.Duration
117+
75118
// TODO(rfindley): file a proposal to export this option, or something equivalent.
76119
configureTransport func(req *http.Request, transport *StreamableServerTransport)
77120
}
@@ -83,8 +126,8 @@ type StreamableHTTPOptions struct {
83126
// If getServer returns nil, a 400 Bad Request will be served.
84127
func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *StreamableHTTPOptions) *StreamableHTTPHandler {
85128
h := &StreamableHTTPHandler{
86-
getServer: getServer,
87-
transports: make(map[string]*StreamableServerTransport),
129+
getServer: getServer,
130+
sessions: make(map[string]*sessionInfo),
88131
}
89132
if opts != nil {
90133
h.opts = *opts
@@ -97,20 +140,27 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea
97140
return h
98141
}
99142

100-
// closeAll closes all ongoing sessions.
143+
// closeAll closes all ongoing sessions, for tests.
101144
//
102145
// TODO(rfindley): investigate the best API for callers to configure their
103146
// session lifecycle. (?)
104147
//
105148
// Should we allow passing in a session store? That would allow the handler to
106149
// be stateless.
107150
func (h *StreamableHTTPHandler) closeAll() {
151+
// TODO: if we ever expose this outside of tests, we'll need to do better
152+
// than simply collecting sessions while holding the lock: we need to prevent
153+
// new sessions from being added.
154+
//
155+
// Currently, sessions remove themselves from h.sessions when closed, so we
156+
// can't call Close while holding the lock.
108157
h.mu.Lock()
109-
defer h.mu.Unlock()
110-
for _, s := range h.transports {
111-
s.connection.Close()
158+
sessionInfos := slices.Collect(maps.Values(h.sessions))
159+
h.sessions = nil
160+
h.mu.Unlock()
161+
for _, s := range sessionInfos {
162+
s.session.Close()
112163
}
113-
h.transports = nil
114164
}
115165

116166
func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
@@ -141,12 +191,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
141191
}
142192

143193
sessionID := req.Header.Get(sessionIDHeader)
144-
var transport *StreamableServerTransport
194+
var sessInfo *sessionInfo
145195
if sessionID != "" {
146196
h.mu.Lock()
147-
transport = h.transports[sessionID]
197+
sessInfo = h.sessions[sessionID]
148198
h.mu.Unlock()
149-
if transport == nil && !h.opts.Stateless {
199+
if sessInfo == nil && !h.opts.Stateless {
150200
// Unless we're in 'stateless' mode, which doesn't perform any Session-ID
151201
// validation, we require that the session ID matches a known session.
152202
//
@@ -161,11 +211,10 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
161211
http.Error(w, "Bad Request: DELETE requires an Mcp-Session-Id header", http.StatusBadRequest)
162212
return
163213
}
164-
if transport != nil { // transport may be nil in stateless mode
165-
h.mu.Lock()
166-
delete(h.transports, transport.SessionID)
167-
h.mu.Unlock()
168-
transport.connection.Close()
214+
if sessInfo != nil { // sessInfo may be nil in stateless mode
215+
// Closing the session also removes it from h.sessions, due to the
216+
// onClose callback.
217+
sessInfo.session.Close()
169218
}
170219
w.WriteHeader(http.StatusNoContent)
171220
return
@@ -222,7 +271,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
222271
return
223272
}
224273

225-
if transport == nil {
274+
if sessInfo == nil {
226275
server := h.getServer(req)
227276
if server == nil {
228277
// The getServer argument to NewStreamableHTTPHandler returned nil.
@@ -234,7 +283,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
234283
// existing transport.
235284
sessionID = server.opts.GetSessionID()
236285
}
237-
transport = &StreamableServerTransport{
286+
transport := &StreamableServerTransport{
238287
SessionID: sessionID,
239288
Stateless: h.opts.Stateless,
240289
jsonResponse: h.opts.JSONResponse,
@@ -300,10 +349,13 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
300349
connectOpts = &ServerSessionOptions{
301350
onClose: func() {
302351
h.mu.Lock()
303-
delete(h.transports, transport.SessionID)
304-
h.mu.Unlock()
305-
if h.onTransportDeletion != nil {
306-
h.onTransportDeletion(transport.SessionID)
352+
defer h.mu.Unlock()
353+
if info, ok := h.sessions[transport.SessionID]; ok {
354+
info.stopTimer()
355+
delete(h.sessions, transport.SessionID)
356+
if h.onTransportDeletion != nil {
357+
h.onTransportDeletion(transport.SessionID)
358+
}
307359
}
308360
},
309361
}
@@ -312,23 +364,36 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
312364
// Pass req.Context() here, to allow middleware to add context values.
313365
// The context is detached in the jsonrpc2 library when handling the
314366
// long-running stream.
315-
ss, err := server.Connect(req.Context(), transport, connectOpts)
367+
session, err := server.Connect(req.Context(), transport, connectOpts)
316368
if err != nil {
317369
http.Error(w, "failed connection", http.StatusInternalServerError)
318370
return
319371
}
372+
sessInfo = &sessionInfo{
373+
session: session,
374+
transport: transport,
375+
}
320376
if h.opts.Stateless {
321377
// Stateless mode: close the session when the request exits.
322-
defer ss.Close() // close the fake session after handling the request
378+
defer session.Close() // close the fake session after handling the request
323379
} else {
324380
// Otherwise, save the transport so that it can be reused
381+
382+
// Clean up the session when it times out.
383+
if h.opts.SessionTimeout > 0 {
384+
sessInfo.timeout = h.opts.SessionTimeout
385+
sessInfo.timer = time.AfterFunc(sessInfo.timeout, func() {
386+
sessInfo.session.Close()
387+
})
388+
}
325389
h.mu.Lock()
326-
h.transports[transport.SessionID] = transport
390+
h.sessions[transport.SessionID] = sessInfo
327391
h.mu.Unlock()
328392
}
329393
}
330394

331-
transport.ServeHTTP(w, req)
395+
sessInfo.resetTimeout()
396+
sessInfo.transport.ServeHTTP(w, req)
332397
}
333398

334399
// A StreamableServerTransport implements the server side of the MCP streamable
@@ -1340,9 +1405,12 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
13401405
go c.handleJSON(requestSummary, resp)
13411406

13421407
case "text/event-stream":
1343-
jsonReq, _ := msg.(*jsonrpc.Request)
1408+
var forCall *jsonrpc.Request
1409+
if jsonReq, ok := msg.(*jsonrpc.Request); ok && jsonReq.IsCall() {
1410+
forCall = jsonReq
1411+
}
13441412
// TODO: should we cancel this logical SSE request if/when jsonReq is canceled?
1345-
go c.handleSSE(requestSummary, resp, false, jsonReq)
1413+
go c.handleSSE(requestSummary, resp, false, forCall)
13461414

13471415
default:
13481416
resp.Body.Close()
@@ -1392,9 +1460,9 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp
13921460
// handleSSE manages the lifecycle of an SSE connection. It can be either
13931461
// persistent (for the main GET listener) or temporary (for a POST response).
13941462
//
1395-
// If forReq is set, it is the request that initiated the stream, and the
1463+
// If forCall is set, it is the call that initiated the stream, and the
13961464
// stream is complete when we receive its response.
1397-
func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) {
1465+
func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *http.Response, persistent bool, forCall *jsonrpc2.Request) {
13981466
resp := initialResp
13991467
var lastEventID string
14001468
for {
@@ -1404,7 +1472,7 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
14041472
// Eventually, if we don't get the response, we should stop trying and
14051473
// fail the request.
14061474
if resp != nil {
1407-
eventID, clientClosed := c.processStream(requestSummary, resp, forReq)
1475+
eventID, clientClosed := c.processStream(requestSummary, resp, forCall)
14081476
lastEventID = eventID
14091477

14101478
// If the connection was closed by the client, we're done.
@@ -1467,11 +1535,11 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
14671535
// incoming channel. It returns the ID of the last processed event and a flag
14681536
// indicating if the connection was closed by the client. If resp is nil, it
14691537
// returns "", false.
1470-
func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) {
1538+
func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forCall *jsonrpc.Request) (lastEventID string, clientClosed bool) {
14711539
defer resp.Body.Close()
14721540
for evt, err := range scanEvents(resp.Body) {
14731541
if err != nil {
1474-
return lastEventID, false
1542+
break
14751543
}
14761544

14771545
if evt.ID != "" {
@@ -1486,10 +1554,10 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
14861554

14871555
select {
14881556
case c.incoming <- msg:
1489-
if jsonResp, ok := msg.(*jsonrpc.Response); ok && forReq != nil {
1557+
if jsonResp, ok := msg.(*jsonrpc.Response); ok && forCall != nil {
14901558
// TODO: we should never get a response when forReq is nil (the standalone SSE request).
14911559
// We should detect this case.
1492-
if jsonResp.ID == forReq.ID {
1560+
if jsonResp.ID == forCall.ID {
14931561
return "", true
14941562
}
14951563
}
@@ -1499,7 +1567,20 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
14991567
}
15001568
}
15011569
// The loop finished without an error, indicating the server closed the stream.
1502-
return "", false
1570+
//
1571+
// If the lastEventID is "", the stream is not retryable and we should
1572+
// report a synthetic error for the call.
1573+
if lastEventID == "" && forCall != nil {
1574+
errmsg := &jsonrpc2.Response{
1575+
ID: forCall.ID,
1576+
Error: fmt.Errorf("request terminated without response"),
1577+
}
1578+
select {
1579+
case c.incoming <- errmsg:
1580+
case <-c.done:
1581+
}
1582+
}
1583+
return lastEventID, false
15031584
}
15041585

15051586
// reconnect handles the logic of retrying a connection with an exponential

0 commit comments

Comments
 (0)