@@ -12,6 +12,7 @@ import (
1212	"fmt" 
1313	"io" 
1414	"log/slog" 
15+ 	"maps" 
1516	"math" 
1617	"math/rand/v2" 
1718	"net/http" 
@@ -40,12 +41,46 @@ type StreamableHTTPHandler struct {
4041	getServer  func (* http.Request ) * Server 
4142	opts       StreamableHTTPOptions 
4243
43- 	onTransportDeletion  func (sessionID  string ) // for testing only  
44+ 	onTransportDeletion  func (sessionID  string ) // for testing 
4445
45- 	mu  sync.Mutex 
46- 	// TODO: we should store the ServerSession along with the transport, because 
47- 	// we need to cancel keepalive requests when closing the transport. 
48- 	transports  map [string ]* StreamableServerTransport  // keyed by IDs (from Mcp-Session-Id header) 
46+ 	mu        sync.Mutex 
47+ 	sessions  map [string ]* sessionInfo  // keyed by session ID 
48+ }
49+ 
50+ type  sessionInfo  struct  {
51+ 	session    * ServerSession 
52+ 	transport  * StreamableServerTransport 
53+ 
54+ 	// If timeout is set, automatically close the session after an idle period. 
55+ 	timeout  time.Duration 
56+ 	timerMu  sync.Mutex 
57+ 	timer    * time.Timer 
58+ }
59+ 
60+ // resetTimeout resets the inactivity timer. 
61+ func  (i  * sessionInfo ) resetTimeout () {
62+ 	if  i .timeout  <=  0  {
63+ 		return 
64+ 	}
65+ 
66+ 	i .timerMu .Lock ()
67+ 	defer  i .timerMu .Unlock ()
68+ 
69+ 	if  i .timer  ==  nil  {
70+ 		return 
71+ 	}
72+ 	// Reset the timer if we successfully stopped it. 
73+ 	i .timer .Reset (i .timeout )
74+ }
75+ 
76+ // stopTimer stops the inactivity timer. 
77+ func  (i  * sessionInfo ) stopTimer () {
78+ 	i .timerMu .Lock ()
79+ 	defer  i .timerMu .Unlock ()
80+ 	if  i .timer  !=  nil  {
81+ 		i .timer .Stop ()
82+ 		i .timer  =  nil 
83+ 	}
4984}
5085
5186// StreamableHTTPOptions configures the StreamableHTTPHandler. 
@@ -72,6 +107,14 @@ type StreamableHTTPOptions struct {
72107	// If nil, do not log. 
73108	Logger  * slog.Logger 
74109
110+ 	// SessionTimeout configures a timeout for idle sessions. 
111+ 	// 
112+ 	// When sessions receive no new HTTP requests from the client for this 
113+ 	// duration, they are automatically closed. 
114+ 	// 
115+ 	// If SessionTimeout is the zero value, idle sessions are never closed. 
116+ 	SessionTimeout  time.Duration 
117+ 
75118	// TODO(rfindley): file a proposal to export this option, or something equivalent. 
76119	configureTransport  func (req  * http.Request , transport  * StreamableServerTransport )
77120}
@@ -83,8 +126,8 @@ type StreamableHTTPOptions struct {
83126// If getServer returns nil, a 400 Bad Request will be served. 
84127func  NewStreamableHTTPHandler (getServer  func (* http.Request ) * Server , opts  * StreamableHTTPOptions ) * StreamableHTTPHandler  {
85128	h  :=  & StreamableHTTPHandler {
86- 		getServer :   getServer ,
87- 		transports :  make (map [string ]* StreamableServerTransport ),
129+ 		getServer : getServer ,
130+ 		sessions :   make (map [string ]* sessionInfo ),
88131	}
89132	if  opts  !=  nil  {
90133		h .opts  =  * opts 
@@ -97,20 +140,27 @@ func NewStreamableHTTPHandler(getServer func(*http.Request) *Server, opts *Strea
97140	return  h 
98141}
99142
100- // closeAll closes all ongoing sessions. 
143+ // closeAll closes all ongoing sessions, for tests . 
101144// 
102145// TODO(rfindley): investigate the best API for callers to configure their 
103146// session lifecycle. (?) 
104147// 
105148// Should we allow passing in a session store? That would allow the handler to 
106149// be stateless. 
107150func  (h  * StreamableHTTPHandler ) closeAll () {
151+ 	// TODO: if we ever expose this outside of tests, we'll need to do better 
152+ 	// than simply collecting sessions while holding the lock: we need to prevent 
153+ 	// new sessions from being added. 
154+ 	// 
155+ 	// Currently, sessions remove themselves from h.sessions when closed, so we 
156+ 	// can't call Close while holding the lock. 
108157	h .mu .Lock ()
109- 	defer  h .mu .Unlock ()
110- 	for  _ , s  :=  range  h .transports  {
111- 		s .connection .Close ()
158+ 	sessionInfos  :=  slices .Collect (maps .Values (h .sessions ))
159+ 	h .sessions  =  nil 
160+ 	h .mu .Unlock ()
161+ 	for  _ , s  :=  range  sessionInfos  {
162+ 		s .session .Close ()
112163	}
113- 	h .transports  =  nil 
114164}
115165
116166func  (h  * StreamableHTTPHandler ) ServeHTTP (w  http.ResponseWriter , req  * http.Request ) {
@@ -141,12 +191,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
141191	}
142192
143193	sessionID  :=  req .Header .Get (sessionIDHeader )
144- 	var  transport   * StreamableServerTransport 
194+ 	var  sessInfo   * sessionInfo 
145195	if  sessionID  !=  ""  {
146196		h .mu .Lock ()
147- 		transport  =  h .transports [sessionID ]
197+ 		sessInfo  =  h .sessions [sessionID ]
148198		h .mu .Unlock ()
149- 		if  transport  ==  nil  &&  ! h .opts .Stateless  {
199+ 		if  sessInfo  ==  nil  &&  ! h .opts .Stateless  {
150200			// Unless we're in 'stateless' mode, which doesn't perform any Session-ID 
151201			// validation, we require that the session ID matches a known session. 
152202			// 
@@ -161,11 +211,10 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
161211			http .Error (w , "Bad Request: DELETE requires an Mcp-Session-Id header" , http .StatusBadRequest )
162212			return 
163213		}
164- 		if  transport  !=  nil  { // transport may be nil in stateless mode 
165- 			h .mu .Lock ()
166- 			delete (h .transports , transport .SessionID )
167- 			h .mu .Unlock ()
168- 			transport .connection .Close ()
214+ 		if  sessInfo  !=  nil  { // sessInfo may be nil in stateless mode 
215+ 			// Closing the session also removes it from h.sessions, due to the 
216+ 			// onClose callback. 
217+ 			sessInfo .session .Close ()
169218		}
170219		w .WriteHeader (http .StatusNoContent )
171220		return 
@@ -222,7 +271,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
222271		return 
223272	}
224273
225- 	if  transport  ==  nil  {
274+ 	if  sessInfo  ==  nil  {
226275		server  :=  h .getServer (req )
227276		if  server  ==  nil  {
228277			// The getServer argument to NewStreamableHTTPHandler returned nil. 
@@ -234,7 +283,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
234283			// existing transport. 
235284			sessionID  =  server .opts .GetSessionID ()
236285		}
237- 		transport  =  & StreamableServerTransport {
286+ 		transport  : =& StreamableServerTransport {
238287			SessionID :    sessionID ,
239288			Stateless :    h .opts .Stateless ,
240289			jsonResponse : h .opts .JSONResponse ,
@@ -300,10 +349,13 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
300349			connectOpts  =  & ServerSessionOptions {
301350				onClose : func () {
302351					h .mu .Lock ()
303- 					delete (h .transports , transport .SessionID )
304- 					h .mu .Unlock ()
305- 					if  h .onTransportDeletion  !=  nil  {
306- 						h .onTransportDeletion (transport .SessionID )
352+ 					defer  h .mu .Unlock ()
353+ 					if  info , ok  :=  h .sessions [transport .SessionID ]; ok  {
354+ 						info .stopTimer ()
355+ 						delete (h .sessions , transport .SessionID )
356+ 						if  h .onTransportDeletion  !=  nil  {
357+ 							h .onTransportDeletion (transport .SessionID )
358+ 						}
307359					}
308360				},
309361			}
@@ -312,23 +364,36 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
312364		// Pass req.Context() here, to allow middleware to add context values. 
313365		// The context is detached in the jsonrpc2 library when handling the 
314366		// long-running stream. 
315- 		ss , err  :=  server .Connect (req .Context (), transport , connectOpts )
367+ 		session , err  :=  server .Connect (req .Context (), transport , connectOpts )
316368		if  err  !=  nil  {
317369			http .Error (w , "failed connection" , http .StatusInternalServerError )
318370			return 
319371		}
372+ 		sessInfo  =  & sessionInfo {
373+ 			session :   session ,
374+ 			transport : transport ,
375+ 		}
320376		if  h .opts .Stateless  {
321377			// Stateless mode: close the session when the request exits. 
322- 			defer  ss .Close () // close the fake session after handling the request 
378+ 			defer  session .Close () // close the fake session after handling the request 
323379		} else  {
324380			// Otherwise, save the transport so that it can be reused 
381+ 
382+ 			// Clean up the session when it times out. 
383+ 			if  h .opts .SessionTimeout  >  0  {
384+ 				sessInfo .timeout  =  h .opts .SessionTimeout 
385+ 				sessInfo .timer  =  time .AfterFunc (sessInfo .timeout , func () {
386+ 					sessInfo .session .Close ()
387+ 				})
388+ 			}
325389			h .mu .Lock ()
326- 			h .transports [transport .SessionID ] =  transport 
390+ 			h .sessions [transport .SessionID ] =  sessInfo 
327391			h .mu .Unlock ()
328392		}
329393	}
330394
331- 	transport .ServeHTTP (w , req )
395+ 	sessInfo .resetTimeout ()
396+ 	sessInfo .transport .ServeHTTP (w , req )
332397}
333398
334399// A StreamableServerTransport implements the server side of the MCP streamable 
@@ -1340,9 +1405,12 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
13401405		go  c .handleJSON (requestSummary , resp )
13411406
13421407	case  "text/event-stream" :
1343- 		jsonReq , _  :=  msg .(* jsonrpc.Request )
1408+ 		var  forCall  * jsonrpc.Request 
1409+ 		if  jsonReq , ok  :=  msg .(* jsonrpc.Request ); ok  &&  jsonReq .IsCall () {
1410+ 			forCall  =  jsonReq 
1411+ 		}
13441412		// TODO: should we cancel this logical SSE request if/when jsonReq is canceled? 
1345- 		go  c .handleSSE (requestSummary , resp , false , jsonReq )
1413+ 		go  c .handleSSE (requestSummary , resp , false , forCall )
13461414
13471415	default :
13481416		resp .Body .Close ()
@@ -1392,9 +1460,9 @@ func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Resp
13921460// handleSSE manages the lifecycle of an SSE connection. It can be either 
13931461// persistent (for the main GET listener) or temporary (for a POST response). 
13941462// 
1395- // If forReq  is set, it is the request  that initiated the stream, and the 
1463+ // If forCall  is set, it is the call  that initiated the stream, and the 
13961464// stream is complete when we receive its response. 
1397- func  (c  * streamableClientConn ) handleSSE (requestSummary  string , initialResp  * http.Response , persistent  bool , forReq  * jsonrpc2.Request ) {
1465+ func  (c  * streamableClientConn ) handleSSE (requestSummary  string , initialResp  * http.Response , persistent  bool , forCall  * jsonrpc2.Request ) {
13981466	resp  :=  initialResp 
13991467	var  lastEventID  string 
14001468	for  {
@@ -1404,7 +1472,7 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
14041472		// Eventually, if we don't get the response, we should stop trying and 
14051473		// fail the request. 
14061474		if  resp  !=  nil  {
1407- 			eventID , clientClosed  :=  c .processStream (requestSummary , resp , forReq )
1475+ 			eventID , clientClosed  :=  c .processStream (requestSummary , resp , forCall )
14081476			lastEventID  =  eventID 
14091477
14101478			// If the connection was closed by the client, we're done. 
@@ -1467,11 +1535,11 @@ func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *htt
14671535// incoming channel. It returns the ID of the last processed event and a flag 
14681536// indicating if the connection was closed by the client. If resp is nil, it 
14691537// returns "", false. 
1470- func  (c  * streamableClientConn ) processStream (requestSummary  string , resp  * http.Response , forReq  * jsonrpc.Request ) (lastEventID  string , clientClosed  bool ) {
1538+ func  (c  * streamableClientConn ) processStream (requestSummary  string , resp  * http.Response , forCall  * jsonrpc.Request ) (lastEventID  string , clientClosed  bool ) {
14711539	defer  resp .Body .Close ()
14721540	for  evt , err  :=  range  scanEvents (resp .Body ) {
14731541		if  err  !=  nil  {
1474- 			return   lastEventID ,  false 
1542+ 			break 
14751543		}
14761544
14771545		if  evt .ID  !=  ""  {
@@ -1486,10 +1554,10 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
14861554
14871555		select  {
14881556		case  c .incoming  <-  msg :
1489- 			if  jsonResp , ok  :=  msg .(* jsonrpc.Response ); ok  &&  forReq  !=  nil  {
1557+ 			if  jsonResp , ok  :=  msg .(* jsonrpc.Response ); ok  &&  forCall  !=  nil  {
14901558				// TODO: we should never get a response when forReq is nil (the standalone SSE request). 
14911559				// We should detect this case. 
1492- 				if  jsonResp .ID  ==  forReq .ID  {
1560+ 				if  jsonResp .ID  ==  forCall .ID  {
14931561					return  "" , true 
14941562				}
14951563			}
@@ -1499,7 +1567,20 @@ func (c *streamableClientConn) processStream(requestSummary string, resp *http.R
14991567		}
15001568	}
15011569	// The loop finished without an error, indicating the server closed the stream. 
1502- 	return  "" , false 
1570+ 	// 
1571+ 	// If the lastEventID is "", the stream is not retryable and we should 
1572+ 	// report a synthetic error for the call. 
1573+ 	if  lastEventID  ==  ""  &&  forCall  !=  nil  {
1574+ 		errmsg  :=  & jsonrpc2.Response {
1575+ 			ID :    forCall .ID ,
1576+ 			Error : fmt .Errorf ("request terminated without response" ),
1577+ 		}
1578+ 		select  {
1579+ 		case  c .incoming  <-  errmsg :
1580+ 		case  <- c .done :
1581+ 		}
1582+ 	}
1583+ 	return  lastEventID , false 
15031584}
15041585
15051586// reconnect handles the logic of retrying a connection with an exponential 
0 commit comments