diff --git a/zk/conn.go b/zk/conn.go index f79a51b3..76db254a 100644 --- a/zk/conn.go +++ b/zk/conn.go @@ -310,10 +310,14 @@ func WithMaxConnBufferSize(maxBufferSize int) connOption { } func (c *Conn) Close() { + rc, err := c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil) + if err != nil { + return + } close(c.shouldQuit) select { - case <-c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil): + case <-rc: case <-time.After(time.Second): } } @@ -933,7 +937,13 @@ func (c *Conn) addWatcher(path string, watchType watchType) <-chan Event { return ch } -func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) <-chan response { +func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (<-chan response, error) { + select { + case <-c.shouldQuit: + return nil, ErrClosing + default: + } + rq := &request{ xid: c.nextXid(), opcode: opcode, @@ -942,13 +952,27 @@ func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recv recvChan: make(chan response, 1), recvFunc: recvFunc, } - c.sendChan <- rq - return rq.recvChan + + select { + case c.sendChan <- rq: + return rq.recvChan, nil + case <-c.shouldQuit: + return nil, ErrClosing + } } func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) { - r := <-c.queueRequest(opcode, req, res, recvFunc) - return r.zxid, r.err + rc, err := c.queueRequest(opcode, req, res, recvFunc) + if err != nil { + return 0, err + } + + select { + case r := <-rc: + return r.zxid, r.err + case <-c.shouldQuit: + return 0, ErrClosing + } } func (c *Conn) AddAuth(scheme string, auth []byte) error { diff --git a/zk/zk_test.go b/zk/zk_test.go index c81ef9fb..76b12290 100644 --- a/zk/zk_test.go +++ b/zk/zk_test.go @@ -666,6 +666,23 @@ func TestRequestFail(t *testing.T) { } } +func TestRequestFailAfterClosed(t *testing.T) { + ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "}) + if err != nil { + t.Fatal(err) + } + defer ts.Stop() + zk, _, err := ts.ConnectAll() + if err != nil { + t.Fatalf("Connect returned error: %+v", err) + } + zk.Close() + _, _, err = zk.Get("/blah") + if err != ErrClosing { + t.Fatalf("unexpected err: %+v", err) + } +} + func TestSlowServer(t *testing.T) { ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "}) if err != nil {