@@ -102,11 +102,15 @@ type Session struct {
102102 // recvDoneCh is closed when recv() exits to avoid a race
103103 // between stream registration and stream shutdown
104104 recvDoneCh chan struct {}
105+ // recvErr is the error the receive loop ended with
106+ recvErr error
105107
106108 // sendDoneCh is closed when send() exits to avoid a race
107109 // between returning from a Stream.Write and exiting from the send loop
108110 // (which may be reading a buffer on-load-from Stream.Write).
109111 sendDoneCh chan struct {}
112+ // sendErr is the error the send loop ended with
113+ sendErr error
110114
111115 // client is true if we're the client and our stream IDs should be odd.
112116 client bool
@@ -288,10 +292,18 @@ func (s *Session) AcceptStream() (*Stream, error) {
288292// semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or
289293// if there's unread data in the kernel receive buffer.
290294func (s * Session ) Close () error {
291- return s .close ( true , goAwayNormal )
295+ return s .closeWithGoAway ( goAwayNormal )
292296}
293297
294- func (s * Session ) close (sendGoAway bool , errCode uint32 ) error {
298+ // CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
299+ // The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
300+ // For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
301+ // receive buffer.
302+ func (s * Session ) CloseWithError (errCode uint32 ) error {
303+ return s .closeWithGoAway (errCode )
304+ }
305+
306+ func (s * Session ) closeWithGoAway (errCode uint32 ) error {
295307 s .shutdownLock .Lock ()
296308 defer s .shutdownLock .Unlock ()
297309
@@ -308,14 +320,12 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
308320 // wait for write loop to exit
309321 _ = s .conn .SetWriteDeadline (time .Now ().Add (- 1 * time .Hour )) // if SetWriteDeadline errored, any blocked writes will be unblocked
310322 <- s .sendDoneCh
311- if sendGoAway {
312- ga := s .goAway (errCode )
313- if err := s .conn .SetWriteDeadline (time .Now ().Add (goAwayWaitTime )); err == nil {
314- _ , _ = s .conn .Write (ga [:]) // there's nothing we can do on error here
315- }
323+ ga := s .goAway (errCode )
324+ if err := s .conn .SetWriteDeadline (time .Now ().Add (goAwayWaitTime )); err == nil {
325+ _ , _ = s .conn .Write (ga [:]) // there's nothing we can do on error here
316326 }
317-
318327 s .conn .SetWriteDeadline (time.Time {})
328+
319329 s .conn .Close ()
320330 <- s .recvDoneCh
321331
@@ -329,15 +339,37 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
329339 return nil
330340}
331341
332- // exitErr is used to handle an error that is causing the
333- // session to terminate.
334- func (s * Session ) exitErr (err error ) {
342+ func (s * Session ) closeWithoutGoAway (err error ) error {
335343 s .shutdownLock .Lock ()
344+ defer s .shutdownLock .Unlock ()
345+ if s .shutdown {
346+ return nil
347+ }
348+ s .shutdown = true
336349 if s .shutdownErr == nil {
337350 s .shutdownErr = err
338351 }
339- s .shutdownLock .Unlock ()
340- s .close (false , 0 )
352+
353+ s .conn .Close ()
354+ <- s .recvDoneCh
355+ // Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code
356+ // received in a GoAway frame received just before the RST that closed the sendLoop
357+ if _ , ok := s .recvErr .(* GoAwayError ); ok {
358+ s .shutdownErr = s .recvErr
359+ }
360+ close (s .shutdownCh )
361+
362+ s .stopKeepalive ()
363+ <- s .sendDoneCh
364+
365+ s .streamLock .Lock ()
366+ defer s .streamLock .Unlock ()
367+ for id , stream := range s .streams {
368+ stream .forceClose ()
369+ delete (s .streams , id )
370+ stream .memorySpan .Done ()
371+ }
372+ return nil
341373}
342374
343375// GoAway can be used to prevent accepting further
@@ -468,7 +500,12 @@ func (s *Session) startKeepalive() {
468500
469501 if err != nil {
470502 s .logger .Printf ("[ERR] yamux: keepalive failed: %v" , err )
471- s .exitErr (ErrKeepAliveTimeout )
503+ s .shutdownLock .Lock ()
504+ if s .shutdownErr == nil {
505+ s .shutdownErr = ErrKeepAliveTimeout
506+ }
507+ s .shutdownLock .Unlock ()
508+ s .closeWithGoAway (goAwayNormal )
472509 }
473510 })
474511}
@@ -533,7 +570,7 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
533570// send is a long running goroutine that sends data
534571func (s * Session ) send () {
535572 if err := s .sendLoop (); err != nil {
536- s .exitErr (err )
573+ s .closeWithoutGoAway (err )
537574 }
538575}
539576
@@ -661,7 +698,7 @@ func (s *Session) sendLoop() (err error) {
661698// recv is a long running goroutine that accepts new data
662699func (s * Session ) recv () {
663700 if err := s .recvLoop (); err != nil {
664- s .exitErr (err )
701+ s .closeWithoutGoAway (err )
665702 }
666703}
667704
@@ -683,7 +720,10 @@ func (s *Session) recvLoop() (err error) {
683720 err = fmt .Errorf ("panic in yamux receive loop: %s" , rerr )
684721 }
685722 }()
686- defer close (s .recvDoneCh )
723+ defer func () {
724+ s .recvErr = err
725+ close (s .recvDoneCh )
726+ }()
687727 var hdr header
688728 for {
689729 // fmt.Printf("ReadFull from %#v\n", s.reader)
@@ -799,17 +839,17 @@ func (s *Session) handleGoAway(hdr header) error {
799839 switch code {
800840 case goAwayNormal :
801841 atomic .SwapInt32 (& s .remoteGoAway , 1 )
842+ // Don't close connection on normal go away. Let the existing streams
843+ // complete gracefully.
844+ return nil
802845 case goAwayProtoErr :
803846 s .logger .Printf ("[ERR] yamux: received protocol error go away" )
804- return fmt .Errorf ("yamux protocol error" )
805847 case goAwayInternalErr :
806848 s .logger .Printf ("[ERR] yamux: received internal error go away" )
807- return fmt .Errorf ("remote yamux internal error" )
808849 default :
809- s .logger .Printf ("[ERR] yamux: received unexpected go away" )
810- return fmt .Errorf ("unexpected go away received" )
850+ s .logger .Printf ("[ERR] yamux: received go away with error code: %d" , code )
811851 }
812- return nil
852+ return & GoAwayError { Remote : true , ErrorCode : code }
813853}
814854
815855// incomingStream is used to create a new incoming stream
0 commit comments