Skip to content

Commit 9d54c73

Browse files
committed
add CloseWithError
1 parent d8cf4e7 commit 9d54c73

File tree

2 files changed

+92
-22
lines changed

2 files changed

+92
-22
lines changed

session.go

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,15 @@ type Session struct {
102102
// recvDoneCh is closed when recv() exits to avoid a race
103103
// between stream registration and stream shutdown
104104
recvDoneCh chan struct{}
105+
// recvErr is the error the receive loop ended with
106+
recvErr error
105107

106108
// sendDoneCh is closed when send() exits to avoid a race
107109
// between returning from a Stream.Write and exiting from the send loop
108110
// (which may be reading a buffer on-load-from Stream.Write).
109111
sendDoneCh chan struct{}
112+
// sendErr is the error the send loop ended with
113+
sendErr error
110114

111115
// client is true if we're the client and our stream IDs should be odd.
112116
client bool
@@ -288,10 +292,18 @@ func (s *Session) AcceptStream() (*Stream, error) {
288292
// semantics of the underlying net.Conn. For TCP connections, it may be dropped depending on LINGER value or
289293
// if there's unread data in the kernel receive buffer.
290294
func (s *Session) Close() error {
291-
return s.close(true, goAwayNormal)
295+
return s.closeWithGoAway(goAwayNormal)
292296
}
293297

294-
func (s *Session) close(sendGoAway bool, errCode uint32) error {
298+
// CloseWithError is used to close the session and all streams after sending a GoAway message with errCode.
299+
// The GoAway may not actually be sent depending on the semantics of the underlying net.Conn.
300+
// For TCP connections, it may be dropped depending on LINGER value or if there's unread data in the kernel
301+
// receive buffer.
302+
func (s *Session) CloseWithError(errCode uint32) error {
303+
return s.closeWithGoAway(errCode)
304+
}
305+
306+
func (s *Session) closeWithGoAway(errCode uint32) error {
295307
s.shutdownLock.Lock()
296308
defer s.shutdownLock.Unlock()
297309

@@ -308,14 +320,12 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
308320
// wait for write loop to exit
309321
_ = s.conn.SetWriteDeadline(time.Now().Add(-1 * time.Hour)) // if SetWriteDeadline errored, any blocked writes will be unblocked
310322
<-s.sendDoneCh
311-
if sendGoAway {
312-
ga := s.goAway(errCode)
313-
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
314-
_, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here
315-
}
323+
ga := s.goAway(errCode)
324+
if err := s.conn.SetWriteDeadline(time.Now().Add(goAwayWaitTime)); err == nil {
325+
_, _ = s.conn.Write(ga[:]) // there's nothing we can do on error here
316326
}
317-
318327
s.conn.SetWriteDeadline(time.Time{})
328+
319329
s.conn.Close()
320330
<-s.recvDoneCh
321331

@@ -329,15 +339,37 @@ func (s *Session) close(sendGoAway bool, errCode uint32) error {
329339
return nil
330340
}
331341

332-
// exitErr is used to handle an error that is causing the
333-
// session to terminate.
334-
func (s *Session) exitErr(err error) {
342+
func (s *Session) closeWithoutGoAway(err error) error {
335343
s.shutdownLock.Lock()
344+
defer s.shutdownLock.Unlock()
345+
if s.shutdown {
346+
return nil
347+
}
348+
s.shutdown = true
336349
if s.shutdownErr == nil {
337350
s.shutdownErr = err
338351
}
339-
s.shutdownLock.Unlock()
340-
s.close(false, 0)
352+
353+
s.conn.Close()
354+
<-s.recvDoneCh
355+
// Prefer the recvLoop error over the sendLoop error. The receive loop might have the error code
356+
// received in a GoAway frame received just before the RST that closed the sendLoop
357+
if _, ok := s.recvErr.(*GoAwayError); ok {
358+
s.shutdownErr = s.recvErr
359+
}
360+
close(s.shutdownCh)
361+
362+
s.stopKeepalive()
363+
<-s.sendDoneCh
364+
365+
s.streamLock.Lock()
366+
defer s.streamLock.Unlock()
367+
for id, stream := range s.streams {
368+
stream.forceClose()
369+
delete(s.streams, id)
370+
stream.memorySpan.Done()
371+
}
372+
return nil
341373
}
342374

343375
// GoAway can be used to prevent accepting further
@@ -468,7 +500,12 @@ func (s *Session) startKeepalive() {
468500

469501
if err != nil {
470502
s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
471-
s.exitErr(ErrKeepAliveTimeout)
503+
s.shutdownLock.Lock()
504+
if s.shutdownErr == nil {
505+
s.shutdownErr = ErrKeepAliveTimeout
506+
}
507+
s.shutdownLock.Unlock()
508+
s.closeWithGoAway(goAwayNormal)
472509
}
473510
})
474511
}
@@ -533,7 +570,7 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
533570
// send is a long running goroutine that sends data
534571
func (s *Session) send() {
535572
if err := s.sendLoop(); err != nil {
536-
s.exitErr(err)
573+
s.closeWithoutGoAway(err)
537574
}
538575
}
539576

@@ -661,7 +698,7 @@ func (s *Session) sendLoop() (err error) {
661698
// recv is a long running goroutine that accepts new data
662699
func (s *Session) recv() {
663700
if err := s.recvLoop(); err != nil {
664-
s.exitErr(err)
701+
s.closeWithoutGoAway(err)
665702
}
666703
}
667704

@@ -683,7 +720,10 @@ func (s *Session) recvLoop() (err error) {
683720
err = fmt.Errorf("panic in yamux receive loop: %s", rerr)
684721
}
685722
}()
686-
defer close(s.recvDoneCh)
723+
defer func() {
724+
s.recvErr = err
725+
close(s.recvDoneCh)
726+
}()
687727
var hdr header
688728
for {
689729
// fmt.Printf("ReadFull from %#v\n", s.reader)
@@ -799,17 +839,17 @@ func (s *Session) handleGoAway(hdr header) error {
799839
switch code {
800840
case goAwayNormal:
801841
atomic.SwapInt32(&s.remoteGoAway, 1)
842+
// Don't close connection on normal go away. Let the existing streams
843+
// complete gracefully.
844+
return nil
802845
case goAwayProtoErr:
803846
s.logger.Printf("[ERR] yamux: received protocol error go away")
804-
return fmt.Errorf("yamux protocol error")
805847
case goAwayInternalErr:
806848
s.logger.Printf("[ERR] yamux: received internal error go away")
807-
return fmt.Errorf("remote yamux internal error")
808849
default:
809-
s.logger.Printf("[ERR] yamux: received unexpected go away")
810-
return fmt.Errorf("unexpected go away received")
850+
s.logger.Printf("[ERR] yamux: received go away with error code: %d", code)
811851
}
812-
return nil
852+
return &GoAwayError{Remote: true, ErrorCode: code}
813853
}
814854

815855
// incomingStream is used to create a new incoming stream

session_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package yamux
33
import (
44
"bytes"
55
"context"
6+
"errors"
67
"fmt"
78
"io"
89
"math/rand"
@@ -650,6 +651,35 @@ func TestGoAway(t *testing.T) {
650651
default:
651652
t.Fatalf("err: %v", err)
652653
}
654+
time.Sleep(50 * time.Millisecond)
655+
}
656+
t.Fatalf("expected GoAway error")
657+
}
658+
659+
func TestCloseWithError(t *testing.T) {
660+
// This test is noisy.
661+
conf := testConf()
662+
conf.LogOutput = io.Discard
663+
664+
client, server := testClientServerConfig(conf)
665+
defer client.Close()
666+
defer server.Close()
667+
668+
if err := server.CloseWithError(42); err != nil {
669+
t.Fatalf("err: %v", err)
670+
}
671+
672+
for i := 0; i < 100; i++ {
673+
s, err := client.Open(context.Background())
674+
if err == nil {
675+
s.Close()
676+
time.Sleep(50 * time.Millisecond)
677+
continue
678+
}
679+
if !errors.Is(err, &GoAwayError{ErrorCode: 42, Remote: true}) {
680+
t.Fatalf("err: %v", err)
681+
}
682+
return
653683
}
654684
t.Fatalf("expected GoAway error")
655685
}

0 commit comments

Comments
 (0)