Skip to content

Commit 5c8df85

Browse files
dennis-trasukunrt
andauthored
fix: deadlock on close (#130)
--------- Co-authored-by: sukun <[email protected]>
1 parent 4ef00e8 commit 5c8df85

File tree

4 files changed

+134
-14
lines changed

4 files changed

+134
-14
lines changed

const.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package yamux
22

33
import (
44
"encoding/binary"
5+
"errors"
56
"fmt"
67
"time"
78
)
@@ -128,6 +129,8 @@ var (
128129

129130
// ErrKeepAliveTimeout is sent if a missed keepalive caused the stream close
130131
ErrKeepAliveTimeout = &Error{msg: "keepalive timeout", timeout: true}
132+
133+
errSendLoopDone = errors.New("send loop done")
131134
)
132135

133136
const (

session.go

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ func (s *Session) close(shutdownErr error, sendGoAway bool, errCode uint32) erro
342342
// GoAway can be used to prevent accepting further
343343
// connections. It does not close the underlying conn.
344344
func (s *Session) GoAway() error {
345-
return s.sendMsg(s.goAway(goAwayNormal), nil, nil)
345+
return s.sendMsg(s.goAway(goAwayNormal), nil, nil, true)
346346
}
347347

348348
// goAway is used to send a goAway message
@@ -499,7 +499,12 @@ func (s *Session) extendKeepalive() {
499499
}
500500

501501
// send sends the header and body.
502-
func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) error {
502+
// If waitForShutDown is true, it will wait for shutdown to complete even if the send loop has exited. This
503+
// ensures accurate error reporting. waitForShutDown should be true for callers other than the recvLoop.
504+
// The recvLoop should set waitForShutdown to false to avoid a deadlock.
505+
// For details see: https://github.com/libp2p/go-yamux/issues/129
506+
// and the test `TestSessionCloseDeadlock`
507+
func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}, waitForShutDown bool) error {
503508
select {
504509
case <-s.shutdownCh:
505510
return s.shutdownErr
@@ -521,6 +526,13 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err
521526
case <-s.shutdownCh:
522527
pool.Put(buf)
523528
return s.shutdownErr
529+
case <-s.sendDoneCh:
530+
pool.Put(buf)
531+
if waitForShutDown {
532+
<-s.shutdownCh
533+
return s.shutdownErr
534+
}
535+
return errSendLoopDone
524536
case s.sendCh <- buf:
525537
return nil
526538
case <-deadline:
@@ -773,7 +785,7 @@ func (s *Session) handleStreamMessage(hdr header) error {
773785

774786
// Read the new data
775787
if err := stream.readData(hdr, flags, s.reader); err != nil {
776-
if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil); sendErr != nil {
788+
if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil, false); sendErr != nil && sendErr != errSendLoopDone {
777789
s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
778790
}
779791
return err
@@ -838,7 +850,7 @@ func (s *Session) incomingStream(id uint32) error {
838850
// Reject immediately if we are doing a go away
839851
if atomic.LoadInt32(&s.localGoAway) == 1 {
840852
hdr := encode(typeWindowUpdate, flagRST, id, 0)
841-
return s.sendMsg(hdr, nil, nil)
853+
return s.sendMsg(hdr, nil, nil, false)
842854
}
843855

844856
// Allocate a new stream
@@ -857,7 +869,7 @@ func (s *Session) incomingStream(id uint32) error {
857869
// Check if stream already exists
858870
if _, ok := s.streams[id]; ok {
859871
s.logger.Printf("[ERR] yamux: duplicate stream declared")
860-
if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil); sendErr != nil {
872+
if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil, false); sendErr != nil && sendErr != errSendLoopDone {
861873
s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
862874
}
863875
span.Done()
@@ -869,7 +881,7 @@ func (s *Session) incomingStream(id uint32) error {
869881
s.logger.Printf("[WARN] yamux: MaxIncomingStreams exceeded, forcing stream reset")
870882
defer span.Done()
871883
hdr := encode(typeWindowUpdate, flagRST, id, 0)
872-
return s.sendMsg(hdr, nil, nil)
884+
return s.sendMsg(hdr, nil, nil, false)
873885
}
874886

875887
s.numIncomingStreams++
@@ -886,7 +898,7 @@ func (s *Session) incomingStream(id uint32) error {
886898
s.logger.Printf("[WARN] yamux: backlog exceeded, forcing stream reset")
887899
s.deleteStream(id)
888900
hdr := encode(typeWindowUpdate, flagRST, id, 0)
889-
return s.sendMsg(hdr, nil, nil)
901+
return s.sendMsg(hdr, nil, nil, false)
890902
}
891903
}
892904

session_test.go

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,7 +1359,7 @@ func TestSession_sendMsg_Timeout(t *testing.T) {
13591359

13601360
hdr := encode(typePing, flagACK, 0, 0)
13611361
for {
1362-
err := client.sendMsg(hdr, nil, nil)
1362+
err := client.sendMsg(hdr, nil, nil, true)
13631363
if err == nil {
13641364
continue
13651365
} else if err == ErrConnectionWriteTimeout {
@@ -1382,14 +1382,14 @@ func TestWindowOverflow(t *testing.T) {
13821382
defer server.Close()
13831383

13841384
hdr1 := encode(typeData, flagSYN, i, 0)
1385-
_ = client.sendMsg(hdr1, nil, nil)
1385+
_ = client.sendMsg(hdr1, nil, nil, true)
13861386
s, err := server.AcceptStream()
13871387
if err != nil {
13881388
t.Fatal(err)
13891389
}
13901390
msg := make([]byte, client.config.MaxStreamWindowSize*2)
13911391
hdr2 := encode(typeData, 0, i, uint32(len(msg)))
1392-
_ = client.sendMsg(hdr2, msg, nil)
1392+
_ = client.sendMsg(hdr2, msg, nil, true)
13931393
_, err = io.ReadAll(s)
13941394
if err == nil {
13951395
t.Fatal("expected to read no data")
@@ -1874,3 +1874,108 @@ func TestErrorCodeErrorIsErrStreamReset(t *testing.T) {
18741874
ge := &GoAwayError{}
18751875
require.True(t, errors.Is(ge, ErrStreamReset))
18761876
}
1877+
1878+
func testTCPConn(t *testing.T) (net.Conn, net.Conn) {
1879+
listener, err := net.Listen("tcp", "127.0.0.1:0")
1880+
if err != nil {
1881+
t.Fatal(err)
1882+
}
1883+
defer listener.Close()
1884+
1885+
// Channel to receive the server-side connection
1886+
serverConnCh := make(chan net.Conn, 1)
1887+
errCh := make(chan error, 1)
1888+
1889+
// Accept connection in goroutine
1890+
go func() {
1891+
conn, err := listener.Accept()
1892+
if err != nil {
1893+
errCh <- err
1894+
return
1895+
}
1896+
serverConnCh <- conn
1897+
}()
1898+
1899+
// Connect to the listener
1900+
clientConn, err := net.Dial("tcp", listener.Addr().String())
1901+
if err != nil {
1902+
t.Fatal(err)
1903+
}
1904+
1905+
// Wait for server connection or error
1906+
select {
1907+
case serverConn := <-serverConnCh:
1908+
return clientConn, serverConn
1909+
case err := <-errCh:
1910+
clientConn.Close()
1911+
t.Fatal(err)
1912+
return nil, nil
1913+
case <-time.After(time.Second):
1914+
clientConn.Close()
1915+
t.Fatal("timeout waiting for connection")
1916+
return nil, nil
1917+
}
1918+
}
1919+
1920+
// TestSessionCloseDeadlock is a regression test for a deadlock on closing connections.
1921+
// See: https://github.com/libp2p/go-yamux/issues/129 for details
1922+
func TestSessionCloseDeadlock(t *testing.T) {
1923+
const n = 10
1924+
const numWrites = 1000
1925+
var closeWG sync.WaitGroup
1926+
closeWG.Add(numWrites * n)
1927+
for i := 0; i < n; i++ {
1928+
go func() {
1929+
conn1, conn2 := testTCPConn(t)
1930+
conf := DefaultConfig()
1931+
conf.LogOutput = io.Discard
1932+
client, err := Client(conn1, conf, nil)
1933+
require.NoError(t, err)
1934+
defer client.Close()
1935+
1936+
conf = DefaultConfig()
1937+
conf.LogOutput = io.Discard
1938+
server, err := Server(conn2, conf, nil)
1939+
require.NoError(t, err)
1940+
defer server.Close()
1941+
1942+
// Create a stream from client
1943+
stream, err := client.OpenStream(context.Background())
1944+
require.NoError(t, err)
1945+
defer stream.Close()
1946+
1947+
// Accept the stream on server side
1948+
serverStream, err := server.AcceptStream()
1949+
require.NoError(t, err)
1950+
defer serverStream.Close()
1951+
1952+
// Send an incomplete dataframe to the server to ensure that the recvloop is reading the rest
1953+
// of the message.
1954+
buf := make([]byte, 1<<10)
1955+
hdr := encode(typeData, serverStream.sendFlags(), serverStream.id, uint32(len(buf))+1000)
1956+
err = server.sendMsg(hdr, buf, nil, true)
1957+
require.NoError(t, err)
1958+
1959+
time.Sleep(1 * time.Second) // Let the read loop block on reading rest of the data frame
1960+
1961+
// Make many writes so that these writes takes up all the buffer space in the channel
1962+
// and the recvLoop deadlocks because it's unable to send the goAwayProtoErr
1963+
var writeWG sync.WaitGroup
1964+
writeWG.Add(numWrites)
1965+
for i := 0; i < numWrites; i++ {
1966+
go func() {
1967+
defer closeWG.Done()
1968+
buf := make([]byte, 1) // make small writes so that we queue up the channel
1969+
writeWG.Done() // This is used to trigger concurrent streamWrite with CloseWrite.
1970+
stream.Write(buf)
1971+
}()
1972+
}
1973+
// Wait for the client to exit
1974+
writeWG.Wait()
1975+
// Close the Write half of the client conn so that the send loop exits first.
1976+
err = conn1.(*net.TCPConn).CloseWrite()
1977+
require.NoError(t, err)
1978+
}()
1979+
}
1980+
closeWG.Wait()
1981+
}

stream.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ START:
182182

183183
// Send the header
184184
hdr = encode(typeData, flags, s.id, max)
185-
if err = s.session.sendMsg(hdr, b[:max], s.writeDeadline.wait()); err != nil {
185+
if err = s.session.sendMsg(hdr, b[:max], s.writeDeadline.wait(), true); err != nil {
186186
return 0, err
187187
}
188188

@@ -241,21 +241,21 @@ func (s *Stream) sendWindowUpdate(deadline <-chan struct{}) error {
241241

242242
s.epochStart = now
243243
hdr := encode(typeWindowUpdate, flags, s.id, delta)
244-
return s.session.sendMsg(hdr, nil, deadline)
244+
return s.session.sendMsg(hdr, nil, deadline, true)
245245
}
246246

247247
// sendClose is used to send a FIN
248248
func (s *Stream) sendClose() error {
249249
flags := s.sendFlags()
250250
flags |= flagFIN
251251
hdr := encode(typeWindowUpdate, flags, s.id, 0)
252-
return s.session.sendMsg(hdr, nil, nil)
252+
return s.session.sendMsg(hdr, nil, nil, true)
253253
}
254254

255255
// sendReset is used to send a RST
256256
func (s *Stream) sendReset(errCode uint32) error {
257257
hdr := encode(typeWindowUpdate, flagRST, s.id, errCode)
258-
return s.session.sendMsg(hdr, nil, nil)
258+
return s.session.sendMsg(hdr, nil, nil, true)
259259
}
260260

261261
// Reset resets the stream (forcibly closes the stream)

0 commit comments

Comments
 (0)