diff --git a/close.go b/close.go index 2de1a5c2..fcc68065 100644 --- a/close.go +++ b/close.go @@ -231,12 +231,6 @@ func (c *Conn) waitGoroutines() error { t := time.NewTimer(time.Second * 15) defer t.Stop() - select { - case <-c.timeoutLoopDone: - case <-t.C: - return errors.New("failed to wait for timeoutLoop goroutine to exit") - } - c.closeReadMu.Lock() closeRead := c.closeReadCtx != nil c.closeReadMu.Unlock() diff --git a/conn.go b/conn.go index 5907bc81..09234871 100644 --- a/conn.go +++ b/conn.go @@ -51,9 +51,8 @@ type Conn struct { br *bufio.Reader bw *bufio.Writer - readTimeout chan context.Context - writeTimeout chan context.Context - timeoutLoopDone chan struct{} + readTimeoutStop atomic.Pointer[func() bool] + writeTimeoutStop atomic.Pointer[func() bool] // Read state. readMu *mu @@ -113,10 +112,6 @@ func newConn(cfg connConfig) *Conn { br: cfg.br, bw: cfg.bw, - readTimeout: make(chan context.Context), - writeTimeout: make(chan context.Context), - timeoutLoopDone: make(chan struct{}), - closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), onPingReceived: cfg.onPingReceived, @@ -144,8 +139,6 @@ func newConn(cfg connConfig) *Conn { c.close() }) - go c.timeoutLoop() - return c } @@ -175,27 +168,34 @@ func (c *Conn) close() error { return err } -func (c *Conn) timeoutLoop() { - defer close(c.timeoutLoopDone) +func (c *Conn) setupWriteTimeout(ctx context.Context) { + stop := context.AfterFunc(ctx, func() { + c.clearWriteTimeout() + c.close() + }) + swapTimeoutStop(&c.writeTimeoutStop, &stop) +} - readCtx := context.Background() - writeCtx := context.Background() +func (c *Conn) clearWriteTimeout() { + swapTimeoutStop(&c.writeTimeoutStop, nil) +} - for { - select { - case <-c.closed: - return - - case writeCtx = <-c.writeTimeout: - case readCtx = <-c.readTimeout: - - case <-readCtx.Done(): - c.close() - return - case <-writeCtx.Done(): - c.close() - return - } +func (c *Conn) setupReadTimeout(ctx context.Context) { + stop := context.AfterFunc(ctx, func() { + c.clearReadTimeout() + c.close() + }) + swapTimeoutStop(&c.readTimeoutStop, &stop) +} + +func (c *Conn) clearReadTimeout() { + swapTimeoutStop(&c.readTimeoutStop, nil) +} + +func swapTimeoutStop(p *atomic.Pointer[func() bool], newStop *func() bool) { + oldStop := p.Swap(newStop) + if oldStop != nil { + (*oldStop)() } } diff --git a/read.go b/read.go index aab9e141..995709e9 100644 --- a/read.go +++ b/read.go @@ -220,14 +220,15 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { // to be called after the read is done. It also returns an error if the // connection is closed. The reference to the error is used to assign // an error depending on if the connection closed or the context timed -// out during use. Typically the referenced error is a named return +// out during use. Typically, the referenced error is a named return // variable of the function calling this method. func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) { select { case <-c.closed: return nil, net.ErrClosed - case c.readTimeout <- ctx: + default: } + c.setupReadTimeout(ctx) done := func() { select { @@ -235,7 +236,8 @@ func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) { if *err != nil { *err = net.ErrClosed } - case c.readTimeout <- context.Background(): + default: + c.clearReadTimeout() } if *err != nil && ctx.Err() != nil { *err = ctx.Err() @@ -280,7 +282,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error return n, fmt.Errorf("failed to read frame payload: %w", err) } - return n, err + return n, nil } func (c *Conn) handleControl(ctx context.Context, h header) (err error) { diff --git a/write.go b/write.go index 7104b227..b5dda23d 100644 --- a/write.go +++ b/write.go @@ -271,12 +271,14 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco select { case <-c.closed: return 0, net.ErrClosed - case c.writeTimeout <- ctx: + default: } + c.setupWriteTimeout(ctx) defer func() { select { case <-c.closed: - case c.writeTimeout <- context.Background(): + default: + c.clearWriteTimeout() } }()