diff --git a/conn.go b/conn.go index 3848ab4a..ecbee898 100644 --- a/conn.go +++ b/conn.go @@ -219,7 +219,7 @@ var validReceivedCloseCodes = map[int]bool{ CloseTLSHandshake: false, } -func isValidReceivedCloseCode(code int) bool { +func isValidCloseCode(code int) bool { return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) } @@ -242,6 +242,7 @@ type Conn struct { conn net.Conn isServer bool subprotocol string + isClosed chan bool // Write fields mu chan bool // used as mutex to protect write to conn @@ -331,6 +332,28 @@ func (c *Conn) Close() error { return c.conn.Close() } + +// Shutdown sends a close frame to the peer and waits for close frame in resopnse. +// Shutdown assumes that the application is reading the connection in another +// goroutine and hence it does not try to read close frame itself +func (c *Conn) Shutdown(closeCode int, closeMessage string, timeout time.Duration) error { + if !isValidCloseCode(closeCode) { + // we do not shutdown connection + return errors.New("invalid close code received") + } + if !utf8.ValidString(closeMessage) { + return errors.New("invalid utf8 payload for shutdown message") + } + + message := FormatCloseMessage(closeCode, closeMessage) + c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) + select { + case <-time.After(timeout): // if nothing happens and we timeout + case <-c.isClosed: // if existing reader encounters close frame + } + return c.Close() +} + // LocalAddr returns the local network address. func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() @@ -496,6 +519,7 @@ func (c *Conn) beginMessage(mw *messageWriter, messageType int) error { // All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and // PongMessage) are supported. func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { + var mw messageWriter if err := c.beginMessage(&mw, messageType); err != nil { return nil, err @@ -898,11 +922,12 @@ func (c *Conn) advanceFrame() (int, error) { return noFrame, err } case CloseMessage: + c.isClosed <- true closeCode := CloseNoStatusReceived closeText := "" if len(payload) >= 2 { closeCode = int(binary.BigEndian.Uint16(payload)) - if !isValidReceivedCloseCode(closeCode) { + if !isValidCloseCode(closeCode) { return noFrame, c.handleProtocolError("invalid close code") } closeText = string(payload[2:])