diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 5b4909a..a8a4515 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -63,7 +63,7 @@ jobs: outformat: out-format with: version: ${{ matrix.golangci }} - args: "--%outformat% colored-line-number" + args: "--%outformat% colored-line-number --timeout 2m" skip-pkg-cache: true skip-build-cache: true diff --git a/client.go b/client.go index 2f769f3..861ae37 100644 --- a/client.go +++ b/client.go @@ -211,7 +211,6 @@ func (c *Client) messageHandler() { for { if c.scanner.Scan() { line := c.scanner.Text() - //nolint: gocritic if line == "error id=0 msg=ok" { c.err <- nil } else if matches := respTrailerRe.FindStringSubmatch(line); len(matches) == 4 { @@ -229,11 +228,12 @@ func (c *Client) messageHandler() { } } else { err := c.scanErr() - c.err <- err - if errors.Is(err, io.ErrUnexpectedEOF) { + if errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) { close(c.disconnect) + c.err <- err return } + c.err <- err } } } @@ -253,9 +253,6 @@ func (c *Client) workHandler() { } func (c *Client) process(data string) { - if err := c.conn.SetWriteDeadline(time.Now().Add(c.timeout)); err != nil { - c.err <- err - } if _, err := c.conn.Write([]byte(data)); err != nil { c.err <- err } diff --git a/client_test.go b/client_test.go index 642dff2..3dc408c 100644 --- a/client_test.go +++ b/client_test.go @@ -128,7 +128,7 @@ func TestClientWriteFail(t *testing.T) { if !assert.NoError(t, err) { return } - assert.NoError(t, c.conn.(*legacyConnection).Conn.(*net.TCPConn).CloseWrite()) + assert.NoError(t, c.conn.(*legacyConnection).Conn.(*writeTimeoutConn).Conn.(*net.TCPConn).CloseWrite()) _, err = c.Exec("version") assert.Error(t, err) diff --git a/connection.go b/connection.go index 8f89c16..82e03d7 100644 --- a/connection.go +++ b/connection.go @@ -20,22 +20,45 @@ const ( DefaultSSHPort = 10022 ) +// writeTimeoutConn is a net.Conn that sets the write timeout on every call to Write(). +type writeTimeoutConn struct { + net.Conn + timeout time.Duration +} + +func (c *writeTimeoutConn) Write(p []byte) (n int, err error) { + if err = c.Conn.SetWriteDeadline(time.Now().Add(c.timeout)); err != nil { + return 0, fmt.Errorf("writeTimeoutConn: SetWriteDeadline: %w", err) + } + if n, err = c.Conn.Write(p); err != nil { + return n, fmt.Errorf("writeTimeoutConn: write: %w", err) + } + return n, nil +} + // legacyConnection is an insecure TCP connection. type legacyConnection struct { net.Conn } // Connect connects to the address with the given timeout. +// The timeout is used as dial and write timeout. func (c *legacyConnection) Connect(addr string, timeout time.Duration) error { addr, err := verifyAddr(addr, DefaultPort) if err != nil { return err } - c.Conn, err = net.DialTimeout("tcp", addr, timeout) + conn, err := net.DialTimeout("tcp", addr, timeout) if err != nil { return fmt.Errorf("legacy connection: dial: %w", err) } + + c.Conn = &writeTimeoutConn{ + Conn: conn, + timeout: timeout, + } + return nil } @@ -47,16 +70,23 @@ type sshConnection struct { } // Connect connects to the address with the given timeout and opens a new SSH channel with attached shell. +// The timeout is used as dial and write timeout. func (c *sshConnection) Connect(addr string, timeout time.Duration) error { addr, err := verifyAddr(addr, DefaultSSHPort) if err != nil { return err } - if c.Conn, err = net.DialTimeout("tcp", addr, timeout); err != nil { + conn, err := net.DialTimeout("tcp", addr, timeout) + if err != nil { return fmt.Errorf("ssh connection: dial: %w", err) } + c.Conn = &writeTimeoutConn{ + Conn: conn, + timeout: timeout, + } + clientConn, chans, reqs, err := ssh.NewClientConn(c.Conn, addr, c.config) if err != nil { return fmt.Errorf("ssh connecion: ssh client conn: %w", err)