@@ -46,10 +46,6 @@ var nullMemoryManager = &nullMemoryManagerImpl{}
4646type Session struct {
4747 rtt int64 // to be accessed atomically, in nanoseconds
4848
49- // remoteGoAway indicates the remote side does
50- // not want futher connections. Must be first for alignment.
51- remoteGoAway int32
52-
5349 // localGoAway indicates that we should stop
5450 // accepting futher connections. Must be first for alignment.
5551 localGoAway int32
@@ -102,6 +98,8 @@ type Session struct {
10298 // recvDoneCh is closed when recv() exits to avoid a race
10399 // between stream registration and stream shutdown
104100 recvDoneCh chan struct {}
101+ // recvErr is the error the receive loop ended with
102+ recvErr error
105103
106104 // sendDoneCh is closed when send() exits to avoid a race
107105 // between returning from a Stream.Write and exiting from the send loop
@@ -203,9 +201,6 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) {
203201 if s .IsClosed () {
204202 return nil , s .shutdownErr
205203 }
206- if atomic .LoadInt32 (& s .remoteGoAway ) == 1 {
207- return nil , ErrRemoteGoAway
208- }
209204
210205 // Block if we have too many inflight SYNs
211206 select {
@@ -283,9 +278,23 @@ func (s *Session) AcceptStream() (*Stream, error) {
283278 }
284279}
285280
286- // Close is used to close the session and all streams.
287- // Attempts to send a GoAway before closing the connection.
281+ // Close is used to close the session and all streams. It doesn't send a GoAway before
282+ // closing the connection.
288283func (s * Session ) Close () error {
284+ return s .close (ErrSessionShutdown , false , goAwayNormal )
285+ }
286+
287+ // CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
288+ // Blocks for ConnectionWriteTimeout to write the GoAway message.
289+ //
290+ // The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
291+ // For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
292+ // receive buffer.
293+ func (s * Session ) CloseWithError (errCode uint32 ) error {
294+ return s .close (& GoAwayError {Remote : false , ErrorCode : errCode }, true , errCode )
295+ }
296+
297+ func (s * Session ) close (shutdownErr error , sendGoAway bool , errCode uint32 ) error {
289298 s .shutdownLock .Lock ()
290299 defer s .shutdownLock .Unlock ()
291300
@@ -294,35 +303,42 @@ func (s *Session) Close() error {
294303 }
295304 s .shutdown = true
296305 if s .shutdownErr == nil {
297- s .shutdownErr = ErrSessionShutdown
306+ s .shutdownErr = shutdownErr
298307 }
299308 close (s .shutdownCh )
300- s .conn .Close ()
301309 s .stopKeepalive ()
302- <- s .recvDoneCh
310+
311+ // Only send GoAway if we have an error code.
312+ if sendGoAway && errCode != goAwayNormal {
313+ // wait for write loop to exit
314+ // We need to write the current frame completely before sending a goaway.
315+ // This will wait for at most s.config.ConnectionWriteTimeout
316+ <- s .sendDoneCh
317+ ga := s .goAway (errCode )
318+ if err := s .conn .SetWriteDeadline (time .Now ().Add (goAwayWaitTime )); err == nil {
319+ _ , _ = s .conn .Write (ga [:]) // there's nothing we can do on error here
320+ }
321+ s .conn .SetWriteDeadline (time.Time {})
322+ }
323+
324+ s .conn .Close ()
303325 <- s .sendDoneCh
326+ <- s .recvDoneCh
304327
328+ resetErr := shutdownErr
329+ if _ , ok := resetErr .(* GoAwayError ); ! ok {
330+ resetErr = fmt .Errorf ("%w: connection closed: %w" , ErrStreamReset , shutdownErr )
331+ }
305332 s .streamLock .Lock ()
306333 defer s .streamLock .Unlock ()
307334 for id , stream := range s .streams {
308- stream .forceClose ()
335+ stream .forceClose (resetErr )
309336 delete (s .streams , id )
310337 stream .memorySpan .Done ()
311338 }
312339 return nil
313340}
314341
315- // exitErr is used to handle an error that is causing the
316- // session to terminate.
317- func (s * Session ) exitErr (err error ) {
318- s .shutdownLock .Lock ()
319- if s .shutdownErr == nil {
320- s .shutdownErr = err
321- }
322- s .shutdownLock .Unlock ()
323- s .Close ()
324- }
325-
326342// GoAway can be used to prevent accepting further
327343// connections. It does not close the underlying conn.
328344func (s * Session ) GoAway () error {
@@ -451,7 +467,7 @@ func (s *Session) startKeepalive() {
451467
452468 if err != nil {
453469 s .logger .Printf ("[ERR] yamux: keepalive failed: %v" , err )
454- s .exitErr (ErrKeepAliveTimeout )
470+ s .close (ErrKeepAliveTimeout , false , 0 )
455471 }
456472 })
457473}
@@ -516,7 +532,25 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
516532// send is a long running goroutine that sends data
517533func (s * Session ) send () {
518534 if err := s .sendLoop (); err != nil {
519- s .exitErr (err )
535+ // If we are shutting down because remote closed the connection, prefer the recvLoop error
536+ // over the sendLoop error. The receive loop might have error code received in a GoAway frame,
537+ // which was received just before the TCP RST that closed the sendLoop.
538+ //
539+ // If we are closing because of an write error, we use the error from the sendLoop and not the recvLoop.
540+ // We hold the shutdownLock, close the connection, and wait for the receive loop to finish and
541+ // use the sendLoop error. Holding the shutdownLock ensures that the recvLoop doesn't trigger connection close
542+ // but the sendLoop does.
543+ s .shutdownLock .Lock ()
544+ if s .shutdownErr == nil {
545+ s .conn .Close ()
546+ <- s .recvDoneCh
547+ if _ , ok := s .recvErr .(* GoAwayError ); ok {
548+ err = s .recvErr
549+ }
550+ s .shutdownErr = err
551+ }
552+ s .shutdownLock .Unlock ()
553+ s .close (err , false , 0 )
520554 }
521555}
522556
@@ -644,7 +678,7 @@ func (s *Session) sendLoop() (err error) {
644678// recv is a long running goroutine that accepts new data
645679func (s * Session ) recv () {
646680 if err := s .recvLoop (); err != nil {
647- s .exitErr (err )
681+ s .close (err , false , 0 )
648682 }
649683}
650684
@@ -666,7 +700,10 @@ func (s *Session) recvLoop() (err error) {
666700 err = fmt .Errorf ("panic in yamux receive loop: %s" , rerr )
667701 }
668702 }()
669- defer close (s .recvDoneCh )
703+ defer func () {
704+ s .recvErr = err
705+ close (s .recvDoneCh )
706+ }()
670707 var hdr header
671708 for {
672709 // fmt.Printf("ReadFull from %#v\n", s.reader)
@@ -781,18 +818,15 @@ func (s *Session) handleGoAway(hdr header) error {
781818 code := hdr .Length ()
782819 switch code {
783820 case goAwayNormal :
784- atomic . SwapInt32 ( & s . remoteGoAway , 1 )
821+ return ErrRemoteGoAway
785822 case goAwayProtoErr :
786823 s .logger .Printf ("[ERR] yamux: received protocol error go away" )
787- return fmt .Errorf ("yamux protocol error" )
788824 case goAwayInternalErr :
789825 s .logger .Printf ("[ERR] yamux: received internal error go away" )
790- return fmt .Errorf ("remote yamux internal error" )
791826 default :
792- s .logger .Printf ("[ERR] yamux: received unexpected go away" )
793- return fmt .Errorf ("unexpected go away received" )
827+ s .logger .Printf ("[ERR] yamux: received go away with error code: %d" , code )
794828 }
795- return nil
829+ return & GoAwayError { Remote : true , ErrorCode : code }
796830}
797831
798832// incomingStream is used to create a new incoming stream
0 commit comments