diff --git a/client.go b/client.go index 101a2501..f42b388b 100644 --- a/client.go +++ b/client.go @@ -395,9 +395,16 @@ func (cli *Client) WaitForConnection(timeout time.Duration) bool { return true } +// Connect connects the client to the WhatsApp web websocket with the specified context. +// After connection, it will either authenticate if there's data in the device store, +// or emit a QREvent to set up a new link. +func (cli *Client) Connect() error { + return cli.ConnectWithTimeout(context.Background()) +} + // Connect connects the client to the WhatsApp web websocket. After connection, it will either // authenticate if there's data in the device store, or emit a QREvent to set up a new link. -func (cli *Client) Connect() error { +func (cli *Client) ConnectWithTimeout(timeCtx context.Context) error { cli.socketLock.Lock() defer cli.socketLock.Unlock() if cli.socket != nil { @@ -430,7 +437,7 @@ func (cli *Client) Connect() error { fs.HTTPHeaders.Set("Sec-Fetch-Mode", "websocket") fs.HTTPHeaders.Set("Sec-Fetch-Site", "cross-site") } - if err := fs.Connect(); err != nil { + if err := fs.Connect(timeCtx); err != nil { fs.Close(0) return err } else if err = cli.doHandshake(fs, *keys.NewKeyPair()); err != nil { diff --git a/message.go b/message.go index 4b08df16..2edf7db2 100644 --- a/message.go +++ b/message.go @@ -169,7 +169,6 @@ func (cli *Client) handlePlaintextMessage(info *types.MessageInfo, node *waBinar } } cli.dispatchEvent(evt.UnwrapRaw()) - return } func (cli *Client) decryptMessages(info *types.MessageInfo, node *waBinary.Node) { diff --git a/socket/framesocket.go b/socket/framesocket.go index 148c7008..52230d51 100644 --- a/socket/framesocket.go +++ b/socket/framesocket.go @@ -93,14 +93,14 @@ func (fs *FrameSocket) Close(code int) { } } -func (fs *FrameSocket) Connect() error { +func (fs *FrameSocket) Connect(timeCtx context.Context) error { fs.lock.Lock() defer fs.lock.Unlock() if fs.conn != nil { return ErrSocketAlreadyOpen } - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(timeCtx) fs.log.Debugf("Dialing %s", fs.URL) conn, _, err := fs.Dialer.Dial(fs.URL, fs.HTTPHeaders)