Skip to content
This repository was archived by the owner on Jul 21, 2021. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions zk/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,14 @@ func WithMaxConnBufferSize(maxBufferSize int) connOption {
}

func (c *Conn) Close() {
rc, err := c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil)
if err != nil {
return
}
close(c.shouldQuit)

select {
case <-c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil):
case <-rc:
case <-time.After(time.Second):
}
}
Expand Down Expand Up @@ -933,7 +937,13 @@ func (c *Conn) addWatcher(path string, watchType watchType) <-chan Event {
return ch
}

func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) <-chan response {
func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (<-chan response, error) {
select {
case <-c.shouldQuit:
return nil, ErrClosing
default:
}

rq := &request{
xid: c.nextXid(),
opcode: opcode,
Expand All @@ -942,13 +952,27 @@ func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recv
recvChan: make(chan response, 1),
recvFunc: recvFunc,
}
c.sendChan <- rq
return rq.recvChan

select {
case c.sendChan <- rq:
return rq.recvChan, nil
case <-c.shouldQuit:
return nil, ErrClosing
}
}

func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) {
r := <-c.queueRequest(opcode, req, res, recvFunc)
return r.zxid, r.err
rc, err := c.queueRequest(opcode, req, res, recvFunc)
if err != nil {
return 0, err
}

select {
case r := <-rc:
return r.zxid, r.err
case <-c.shouldQuit:
return 0, ErrClosing
}
}

func (c *Conn) AddAuth(scheme string, auth []byte) error {
Expand Down
17 changes: 17 additions & 0 deletions zk/zk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,23 @@ func TestRequestFail(t *testing.T) {
}
}

func TestRequestFailAfterClosed(t *testing.T) {
ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "})
if err != nil {
t.Fatal(err)
}
defer ts.Stop()
zk, _, err := ts.ConnectAll()
if err != nil {
t.Fatalf("Connect returned error: %+v", err)
}
zk.Close()
_, _, err = zk.Get("/blah")
if err != ErrClosing {
t.Fatalf("unexpected err: %+v", err)
}
}

func TestSlowServer(t *testing.T) {
ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "})
if err != nil {
Expand Down