@@ -102,6 +102,8 @@ type Session struct {
102102 // recvDoneCh is closed when recv() exits to avoid a race
103103 // between stream registration and stream shutdown
104104 recvDoneCh chan struct {}
105+ // recvErr is the error the receive loop ended with
106+ recvErr error
105107
106108 // sendDoneCh is closed when send() exits to avoid a race
107109 // between returning from a Stream.Write and exiting from the send loop
@@ -288,10 +290,18 @@ func (s *Session) AcceptStream() (*Stream, error) {
288290// semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or
289291// if there's unread data in the kernel receive buffer.
290292func (s * Session ) Close () error {
291- return s .close ( true , goAwayNormal )
293+ return s .closeWithGoAway ( goAwayNormal )
292294}
293295
294- func (s * Session ) close (sendGoAway bool , errCode uint32 ) error {
296+ // CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
297+ // The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
298+ // For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
299+ // receive buffer.
300+ func (s * Session ) CloseWithError (errCode uint32 ) error {
301+ return s .closeWithGoAway (errCode )
302+ }
303+
304+ func (s * Session ) closeWithGoAway (errCode uint32 ) error {
295305 s .shutdownLock .Lock ()
296306 defer s .shutdownLock .Unlock ()
297307
@@ -308,14 +318,12 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
308318 // wait for write loop to exit
309319 _ = s .conn .SetWriteDeadline (time .Now ().Add (- 1 * time .Hour )) // if SetWriteDeadline errored, any blocked writes will be unblocked
310320 <- s .sendDoneCh
311- if sendGoAway {
312- ga := s .goAway (errCode )
313- if err := s .conn .SetWriteDeadline (time .Now ().Add (goAwayWaitTime )); err == nil {
314- _ , _ = s .conn .Write (ga [:]) // there's nothing we can do on error here
315- }
321+ ga := s .goAway (errCode )
322+ if err := s .conn .SetWriteDeadline (time .Now ().Add (goAwayWaitTime )); err == nil {
323+ _ , _ = s .conn .Write (ga [:]) // there's nothing we can do on error here
316324 }
317-
318325 s .conn .SetWriteDeadline (time.Time {})
326+
319327 s .conn .Close ()
320328 <- s .recvDoneCh
321329
@@ -329,15 +337,37 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
329337 return nil
330338}
331339
332- // exitErr is used to handle an error that is causing the
333- // session to terminate.
334- func (s * Session ) exitErr (err error ) {
340+ func (s * Session ) closeWithoutGoAway (err error ) error {
335341 s .shutdownLock .Lock ()
342+ defer s .shutdownLock .Unlock ()
343+ if s .shutdown {
344+ return nil
345+ }
346+ s .shutdown = true
336347 if s .shutdownErr == nil {
337348 s .shutdownErr = err
338349 }
339- s .shutdownLock .Unlock ()
340- s .close (false , 0 )
350+
351+ s .conn .Close ()
352+ <- s .recvDoneCh
353+ // Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code
354+ // received in a GoAway frame received just before the RST that closed the sendLoop
355+ if _ , ok := s .recvErr .(* GoAwayError ); ok {
356+ s .shutdownErr = s .recvErr
357+ }
358+ close (s .shutdownCh )
359+
360+ s .stopKeepalive ()
361+ <- s .sendDoneCh
362+
363+ s .streamLock .Lock ()
364+ defer s .streamLock .Unlock ()
365+ for id , stream := range s .streams {
366+ stream .forceClose ()
367+ delete (s .streams , id )
368+ stream .memorySpan .Done ()
369+ }
370+ return nil
341371}
342372
343373// GoAway can be used to prevent accepting further
@@ -468,7 +498,12 @@ func (s *Session) startKeepalive() {
468498
469499 if err != nil {
470500 s .logger .Printf ("[ERR] yamux: keepalive failed: %v" , err )
471- s .exitErr (ErrKeepAliveTimeout )
501+ s .shutdownLock .Lock ()
502+ if s .shutdownErr == nil {
503+ s .shutdownErr = ErrKeepAliveTimeout
504+ }
505+ s .shutdownLock .Unlock ()
506+ s .closeWithGoAway (goAwayNormal )
472507 }
473508 })
474509}
@@ -533,7 +568,7 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
533568// send is a long running goroutine that sends data
534569func (s * Session ) send () {
535570 if err := s .sendLoop (); err != nil {
536- s .exitErr (err )
571+ s .closeWithoutGoAway (err )
537572 }
538573}
539574
@@ -661,7 +696,7 @@ func (s *Session) sendLoop() (err error) {
661696// recv is a long running goroutine that accepts new data
662697func (s * Session ) recv () {
663698 if err := s .recvLoop (); err != nil {
664- s .exitErr (err )
699+ s .closeWithoutGoAway (err )
665700 }
666701}
667702
@@ -683,7 +718,10 @@ func (s *Session) recvLoop() (err error) {
683718 err = fmt .Errorf ("panic in yamux receive loop: %s" , rerr )
684719 }
685720 }()
686- defer close (s .recvDoneCh )
721+ defer func () {
722+ s .recvErr = err
723+ close (s .recvDoneCh )
724+ }()
687725 var hdr header
688726 for {
689727 // fmt.Printf("ReadFull from %#v\n", s.reader)
@@ -799,17 +837,17 @@ func (s *Session) handleGoAway(hdr header) error {
799837 switch code {
800838 case goAwayNormal :
801839 atomic .SwapInt32 (& s .remoteGoAway , 1 )
840+ // Don't close connection on normal go away. Let the existing streams
841+ // complete gracefully.
842+ return nil
802843 case goAwayProtoErr :
803844 s .logger .Printf ("[ERR] yamux: received protocol error go away" )
804- return fmt .Errorf ("yamux protocol error" )
805845 case goAwayInternalErr :
806846 s .logger .Printf ("[ERR] yamux: received internal error go away" )
807- return fmt .Errorf ("remote yamux internal error" )
808847 default :
809- s .logger .Printf ("[ERR] yamux: received unexpected go away" )
810- return fmt .Errorf ("unexpected go away received" )
848+ s .logger .Printf ("[ERR] yamux: received go away with error code: %d" , code )
811849 }
812- return nil
850+ return & GoAwayError { Remote : true , ErrorCode : code }
813851}
814852
815853// incomingStream is used to create a new incoming stream
0 commit comments