Skip to content

Commit 099db8b

Browse files
committed
Add custom connection handler for websocket keepalive
Allows passing a ConnHandler to wsServeWithConnHandler for custom connection management, enabling active keepalive via ping messages. WsAnnouncementServe now uses this to send periodic pings.
1 parent d9bf299 commit 099db8b

File tree

3 files changed

+37
-7
lines changed

3 files changed

+37
-7
lines changed

v2/websocket.go

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package binance
22

33
import (
4+
"context"
45
"net/http"
56
"net/url"
67
"time"
@@ -28,7 +29,14 @@ func newWsConfig(endpoint string) *WsConfig {
2829
}
2930
}
3031

31-
var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (doneC, stopC chan struct{}, err error) {
32+
func wsServe(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (doneC, stopC chan struct{}, err error) {
33+
return wsServeWithConnHandler(cfg, handler, errHandler, nil)
34+
}
35+
36+
type ConnHandler func(context.Context, *websocket.Conn)
37+
38+
// WsServeWithConnHandler serves websocket with custom connection handler, useful for custom keepalive
39+
var wsServeWithConnHandler = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler, connHandler ConnHandler) (doneC, stopC chan struct{}, err error) {
3240
proxy := http.ProxyFromEnvironment
3341
if cfg.Proxy != nil {
3442
u, err := url.Parse(*cfg.Proxy)
@@ -62,6 +70,13 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don
6270
keepAlive(c, WebsocketTimeout)
6371
}
6472

73+
// Custom connection handling, useful in active keepalive scenarios
74+
if connHandler != nil {
75+
ctx, cancel := context.WithCancel(context.Background())
76+
defer cancel()
77+
go connHandler(ctx, c)
78+
}
79+
6580
// Wait for the stopC channel to be closed. We do that in a
6681
// separate goroutine because ReadMessage is a blocking
6782
// operation.
@@ -88,6 +103,21 @@ var wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (don
88103
return
89104
}
90105

106+
func keepAliveHandler(duration time.Duration) ConnHandler {
107+
return func(ctx context.Context, c *websocket.Conn) {
108+
ticker := time.NewTicker(duration)
109+
for {
110+
select {
111+
case <-ctx.Done():
112+
ticker.Stop()
113+
return
114+
case <-ticker.C:
115+
c.WriteMessage(websocket.PingMessage, []byte{})
116+
}
117+
}
118+
}
119+
}
120+
91121
func keepAlive(c *websocket.Conn, timeout time.Duration) {
92122
ticker := time.NewTicker(timeout)
93123

@@ -137,7 +167,7 @@ var WsGetReadWriteConnection = func(cfg *WsConfig) (*websocket.Conn, error) {
137167
EnableCompression: false,
138168
}
139169

140-
c, _, err := Dialer.Dial(cfg.Endpoint, nil)
170+
c, _, err := Dialer.Dial(cfg.Endpoint, *cfg.Header)
141171
if err != nil {
142172
return nil, err
143173
}

v2/websocket_service.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -946,7 +946,7 @@ func WsAnnouncementServe(params WsAnnouncementParam, handler WsAnnouncementHandl
946946
json.Unmarshal([]byte(event.Data), &e)
947947
handler(e)
948948
}
949-
return wsServe(cfg, wsHandler, errHandler)
949+
return wsServeWithConnHandler(cfg, wsHandler, errHandler, keepAliveHandler(30*time.Second))
950950
}
951951

952952
// getWsApiEndpoint return the base endpoint of the API WS according the UseTestnet flag

v2/websocket_service_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99

1010
type websocketServiceTestSuite struct {
1111
baseTestSuite
12-
origWsServe func(*WsConfig, WsHandler, ErrHandler) (chan struct{}, chan struct{}, error)
12+
origWsServe func(*WsConfig, WsHandler, ErrHandler, ConnHandler) (chan struct{}, chan struct{}, error)
1313
serveCount int
1414
}
1515

@@ -18,16 +18,16 @@ func TestWebsocketService(t *testing.T) {
1818
}
1919

2020
func (s *websocketServiceTestSuite) SetupTest() {
21-
s.origWsServe = wsServe
21+
s.origWsServe = wsServeWithConnHandler
2222
}
2323

2424
func (s *websocketServiceTestSuite) TearDownTest() {
25-
wsServe = s.origWsServe
25+
wsServeWithConnHandler = s.origWsServe
2626
s.serveCount = 0
2727
}
2828

2929
func (s *websocketServiceTestSuite) mockWsServe(data []byte, err error) {
30-
wsServe = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler) (doneC, stopC chan struct{}, innerErr error) {
30+
wsServeWithConnHandler = func(cfg *WsConfig, handler WsHandler, errHandler ErrHandler, connHandler ConnHandler) (doneC, stopC chan struct{}, innerErr error) {
3131
s.serveCount++
3232
doneC = make(chan struct{})
3333
stopC = make(chan struct{})

0 commit comments

Comments
 (0)