diff --git a/apps/nsqd/main.go b/apps/nsqd/main.go index cd2121376..791a21f5c 100644 --- a/apps/nsqd/main.go +++ b/apps/nsqd/main.go @@ -4,6 +4,7 @@ import ( "context" "flag" "fmt" + "log" "math/rand" "os" "os/signal" @@ -66,6 +67,10 @@ func (p *program) Start() error { options.Resolve(opts, flagSet, cfg) + if err := opts.Validate(); err != nil { + log.Fatal(err) + } + nsqd, err := nsqd.New(opts) if err != nil { logFatal("failed to instantiate nsqd - %s", err) diff --git a/apps/nsqd/options.go b/apps/nsqd/options.go index 947942349..97dc2a538 100644 --- a/apps/nsqd/options.go +++ b/apps/nsqd/options.go @@ -113,6 +113,7 @@ func nsqdFlagSet(opts *nsqd.Options) *flag.FlagSet { logLevel := opts.LogLevel flagSet.Var(&logLevel, "log-level", "set log verbosity: debug, info, warn, error, or fatal") flagSet.String("log-prefix", "[nsqd] ", "log message prefix") + flagSet.String("sigterm-mode", "shutdown", "action to take on a SIGTERM (shutdown, drain)") flagSet.Bool("verbose", false, "[deprecated] has no effect, use --log-level") flagSet.Int64("node-id", opts.ID, "unique part for message IDs, (int) in range [0,1024) (default is hash of hostname)") diff --git a/internal/http_api/topic_channel_args.go b/internal/http_api/topic_channel_args.go index 113e02538..7a444bd6d 100644 --- a/internal/http_api/topic_channel_args.go +++ b/internal/http_api/topic_channel_args.go @@ -10,24 +10,40 @@ type getter interface { Get(key string) (string, error) } -func GetTopicChannelArgs(rp getter) (string, string, error) { +// GetTopicArg returns the ?topic parameter +func GetTopicArg(rp getter) (string, error) { topicName, err := rp.Get("topic") if err != nil { - return "", "", errors.New("MISSING_ARG_TOPIC") + return "", errors.New("MISSING_ARG_TOPIC") } if !protocol.IsValidTopicName(topicName) { - return "", "", errors.New("INVALID_ARG_TOPIC") + return "", errors.New("INVALID_ARG_TOPIC") } + return topicName, nil +} +// GetChannelArg returns the ?channel parameter +func GetChannelArg(rp getter) (string, error) { channelName, err := rp.Get("channel") if err != nil { - return "", "", errors.New("MISSING_ARG_CHANNEL") + return "", errors.New("MISSING_ARG_CHANNEL") } if !protocol.IsValidChannelName(channelName) { - return "", "", errors.New("INVALID_ARG_CHANNEL") + return "", errors.New("INVALID_ARG_CHANNEL") } + return channelName, nil +} +func GetTopicChannelArgs(rp getter) (string, string, error) { + topicName, err := GetTopicArg(rp) + if err != nil { + return "", "", err + } + channelName, err := GetChannelArg(rp) + if err != nil { + return "", "", err + } return topicName, channelName, nil } diff --git a/internal/test/assertions.go b/internal/test/assertions.go index 330dfec4b..5e0429f59 100644 --- a/internal/test/assertions.go +++ b/internal/test/assertions.go @@ -1,6 +1,9 @@ package test import ( + "encoding/json" + "io/ioutil" + "net/http" "path/filepath" "reflect" "runtime" @@ -56,3 +59,18 @@ func isNil(object interface{}) bool { return false } + +func HTTPError(t *testing.T, resp *http.Response, code int, message string) { + type ErrMessage struct { + Message string `json:"message"` + } + + body, _ := ioutil.ReadAll(resp.Body) + resp.Body.Close() + t.Log(string(body)) + Equal(t, code, resp.StatusCode) + + var em ErrMessage + Nil(t, json.Unmarshal(body, &em)) + Equal(t, message, em.Message) +} diff --git a/internal/test/guids.go b/internal/test/guids.go new file mode 100644 index 000000000..a89562f6c --- /dev/null +++ b/internal/test/guids.go @@ -0,0 +1,15 @@ +package test + +import ( + "sync/atomic" +) + +// GUIDFactory is an atomic sequence that can be used for MessageID's for benchmarks +// to avoid ErrSequenceExpired when creating a large number of messages +type GUIDFactory struct { + n int64 +} + +func (gf *GUIDFactory) NextMessageID() int64 { + return atomic.AddInt64(&gf.n, 1) +} diff --git a/nsqadmin/http_test.go b/nsqadmin/http_test.go index 9dc20960c..b415d808a 100644 --- a/nsqadmin/http_test.go +++ b/nsqadmin/http_test.go @@ -159,7 +159,7 @@ func TestHTTPTopicsGET(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_topics_get" + strconv.Itoa(int(time.Now().Unix())) - nsqds[0].GetTopic(topicName) + nsqds[0].GetOrCreateTopic(topicName) time.Sleep(100 * time.Millisecond) client := http.Client{} @@ -187,7 +187,7 @@ func TestHTTPTopicGET(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_topic_get" + strconv.Itoa(int(time.Now().Unix())) - nsqds[0].GetTopic(topicName) + nsqds[0].GetOrCreateTopic(topicName) time.Sleep(100 * time.Millisecond) client := http.Client{} @@ -253,8 +253,9 @@ func TestHTTPChannelGET(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_channel_get" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqds[0].GetTopic(topicName) - topic.GetChannel("ch") + topic, err := nsqds[0].GetOrCreateTopic(topicName) + test.Nil(t, err) + topic.GetOrCreateChannel("ch") time.Sleep(100 * time.Millisecond) client := http.Client{} @@ -292,8 +293,9 @@ func TestHTTPNodesSingleGET(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_nodes_single_get" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqds[0].GetTopic(topicName) - topic.GetChannel("ch") + topic, err := nsqds[0].GetOrCreateTopic(topicName) + test.Nil(t, err) + topic.GetOrCreateChannel("ch") time.Sleep(100 * time.Millisecond) client := http.Client{} @@ -376,7 +378,7 @@ func TestHTTPTombstoneTopicNodePOST(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_tombstone_topic_node_post" + strconv.Itoa(int(time.Now().Unix())) - nsqds[0].GetTopic(topicName) + nsqds[0].GetOrCreateTopic(topicName) time.Sleep(100 * time.Millisecond) client := http.Client{} @@ -399,7 +401,7 @@ func TestHTTPDeleteTopicPOST(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_delete_topic_post" + strconv.Itoa(int(time.Now().Unix())) - nsqds[0].GetTopic(topicName) + nsqds[0].GetOrCreateTopic(topicName) time.Sleep(100 * time.Millisecond) client := http.Client{} @@ -419,8 +421,9 @@ func TestHTTPDeleteChannelPOST(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_delete_channel_post" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqds[0].GetTopic(topicName) - topic.GetChannel("ch") + topic, err := nsqds[0].GetOrCreateTopic(topicName) + test.Nil(t, err) + topic.GetOrCreateChannel("ch") time.Sleep(100 * time.Millisecond) client := http.Client{} @@ -440,7 +443,7 @@ func TestHTTPPauseTopicPOST(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_pause_topic_post" + strconv.Itoa(int(time.Now().Unix())) - nsqds[0].GetTopic(topicName) + nsqds[0].GetOrCreateTopic(topicName) time.Sleep(100 * time.Millisecond) client := http.Client{} @@ -474,8 +477,9 @@ func TestHTTPPauseChannelPOST(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_pause_channel_post" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqds[0].GetTopic(topicName) - topic.GetChannel("ch") + topic, err := nsqds[0].GetOrCreateTopic(topicName) + test.Nil(t, err) + topic.GetOrCreateChannel("ch") time.Sleep(100 * time.Millisecond) client := http.Client{} @@ -509,7 +513,8 @@ func TestHTTPEmptyTopicPOST(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_empty_topic_post" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqds[0].GetTopic(topicName) + topic, err := nsqds[0].GetOrCreateTopic(topicName) + test.Nil(t, err) topic.PutMessage(nsqd.NewMessage(nsqd.MessageID{}, []byte("1234"))) test.Equal(t, int64(1), topic.Depth()) time.Sleep(100 * time.Millisecond) @@ -537,8 +542,9 @@ func TestHTTPEmptyChannelPOST(t *testing.T) { defer nsqadmin1.Exit() topicName := "test_empty_channel_post" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqds[0].GetTopic(topicName) - channel := topic.GetChannel("ch") + topic, err := nsqds[0].GetOrCreateTopic(topicName) + test.Nil(t, err) + channel := topic.GetOrCreateChannel("ch") channel.PutMessage(nsqd.NewMessage(nsqd.MessageID{}, []byte("1234"))) time.Sleep(100 * time.Millisecond) diff --git a/nsqd/channel.go b/nsqd/channel.go index 8838c1e3d..b3d6db9d6 100644 --- a/nsqd/channel.go +++ b/nsqd/channel.go @@ -54,6 +54,7 @@ type Channel struct { // state tracking clients map[int64]Consumer paused int32 + isDraining int32 ephemeral bool deleteCallback func(*Channel) deleter sync.Once @@ -122,6 +123,20 @@ func NewChannel(topicName string, channelName string, nsqd *NSQD, return c } +// InFlightCount returns the number of messages that have been sent to a client, but not yet FIN or REQUEUE'd +func (c *Channel) InFlightCount() int64 { + c.inFlightMutex.Lock() + defer c.inFlightMutex.Unlock() + return int64(len(c.inFlightMessages)) +} + +// DeferredCount returns the number of messages that are queued in-memory for future delivery to a client +func (c *Channel) DeferredCount() int64 { + c.deferredMutex.Lock() + defer c.deferredMutex.Unlock() + return int64(len(c.deferredMessages)) +} + func (c *Channel) initPQ() { pqSize := int(math.Max(1, float64(c.nsqd.getOpts().MemQueueSize)/10)) @@ -136,6 +151,23 @@ func (c *Channel) initPQ() { c.deferredMutex.Unlock() } +// StartDraining starts draining a channel +// +// if there are no outstanding messages the channel is deleted immediately +// if there are messages outstanding it's deleted by FinishMessage +func (c *Channel) StartDraining() { + if !atomic.CompareAndSwapInt32(&c.isDraining, 0, 1) { + return + } + depth, inFlight, deferred := c.Depth(), c.InFlightCount(), c.DeferredCount() + c.nsqd.logf(LOG_INFO, "CHANNEL(%s): draining. depth:%d inFlight:%d deferred:%d", c.name, depth, inFlight, deferred) + // if we are empty delete + if depth+inFlight+deferred == 0 { + go c.deleter.Do(func() { c.deleteCallback(c) }) + } + // else cleanup happens on last FinishMessage +} + // Exiting returns a boolean indicating if this channel is closed/exiting func (c *Channel) Exiting() bool { return atomic.LoadInt32(&c.exitFlag) == 1 @@ -187,6 +219,9 @@ func (c *Channel) exit(deleted bool) error { return c.backend.Close() } +// Empty drains the channel of messages. +// +// If the channel is draining this will delete the channel func (c *Channel) Empty() error { c.Lock() defer c.Unlock() @@ -196,16 +231,27 @@ func (c *Channel) Empty() error { client.Empty() } +MemoryDrain: for { select { case <-c.memoryMsgChan: default: - goto finish + break MemoryDrain } } -finish: - return c.backend.Empty() + err := c.backend.Empty() + + // `backend.Empty` always results in an internal empty state (even if on-disk state might differ) + // so we want to logically continue to finish draining if applicable. + if atomic.LoadInt32(&c.isDraining) == 1 { + go c.deleter.Do(func() { c.deleteCallback(c) }) + } + + if err != nil { + return err + } + return nil } // flush persists all the messages in internal memory buffers to the backend @@ -346,6 +392,8 @@ func (c *Channel) TouchMessage(clientID int64, id MessageID, clientMsgTimeout ti } // FinishMessage successfully discards an in-flight message +// +// if this channel is draining and this is the last message this will initiate a channel deletion func (c *Channel) FinishMessage(clientID int64, id MessageID) error { msg, err := c.popInFlightMessage(clientID, id) if err != nil { @@ -355,6 +403,15 @@ func (c *Channel) FinishMessage(clientID int64, id MessageID) error { if c.e2eProcessingLatencyStream != nil { c.e2eProcessingLatencyStream.Insert(msg.Timestamp) } + + if atomic.LoadInt32(&c.isDraining) == 1 { + // if last msg, delete + depth, inFlight, deferred := c.Depth(), c.InFlightCount(), c.DeferredCount() + if depth+inFlight+deferred == 0 { + c.nsqd.logf(LOG_INFO, "CHANNEL(%s): draining complete", c.name) + go c.deleter.Do(func() { c.deleteCallback(c) }) + } + } return nil } diff --git a/nsqd/channel_test.go b/nsqd/channel_test.go index 96386f4cc..775519a85 100644 --- a/nsqd/channel_test.go +++ b/nsqd/channel_test.go @@ -21,8 +21,8 @@ func TestPutMessage(t *testing.T) { defer nsqd.Exit() topicName := "test_put_message" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) - channel1 := topic.GetChannel("ch") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel1, _ := topic.GetOrCreateChannel("ch") var id MessageID msg := NewMessage(id, []byte("test")) @@ -42,9 +42,9 @@ func TestPutMessage2Chan(t *testing.T) { defer nsqd.Exit() topicName := "test_put_message_2chan" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) - channel1 := topic.GetChannel("ch1") - channel2 := topic.GetChannel("ch2") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel1, _ := topic.GetOrCreateChannel("ch1") + channel2, _ := topic.GetOrCreateChannel("ch2") var id MessageID msg := NewMessage(id, []byte("test")) @@ -71,8 +71,8 @@ func TestInFlightWorker(t *testing.T) { defer nsqd.Exit() topicName := "test_in_flight_worker" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) - channel := topic.GetChannel("channel") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel("channel") for i := 0; i < count; i++ { msg := NewMessage(topic.GenerateID(), []byte("test")) @@ -112,8 +112,8 @@ func TestChannelEmpty(t *testing.T) { defer nsqd.Exit() topicName := "test_channel_empty" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) - channel := topic.GetChannel("channel") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel("channel") msgs := make([]*Message, 0, 25) for i := 0; i < 25; i++ { @@ -148,8 +148,8 @@ func TestChannelEmptyConsumer(t *testing.T) { defer conn.Close() topicName := "test_channel_empty" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) - channel := topic.GetChannel("channel") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel("channel") client := newClientV2(0, conn, nsqd) client.SetReadyCount(25) err := channel.AddClient(client.ID, client) @@ -186,8 +186,8 @@ func TestMaxChannelConsumers(t *testing.T) { defer conn.Close() topicName := "test_max_channel_consumers" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) - channel := topic.GetChannel("channel") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel("channel") client1 := newClientV2(1, conn, nsqd) client1.SetReadyCount(25) @@ -209,9 +209,9 @@ func TestChannelHealth(t *testing.T) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - topic := nsqd.GetTopic("test") + topic, _ := nsqd.GetOrCreateTopic("test") - channel := topic.GetChannel("channel") + channel, _ := topic.GetOrCreateChannel("channel") channel.backend = &errorBackendQueue{} @@ -248,3 +248,42 @@ func TestChannelHealth(t *testing.T) { resp.Body.Close() test.Equal(t, "OK", string(body)) } + +// TestChannelDraining ensures a channel with an outstanding message is deleted after message consumption is finished +func TestChannelDraining(t *testing.T) { + opts := NewOptions() + opts.Logger = test.NewTestLogger(t) + _, _, nsqd := mustStartNSQD(opts) + defer os.RemoveAll(opts.DataPath) + defer nsqd.Exit() + + topicName := "test_drain_channel" + strconv.Itoa(int(time.Now().Unix())) + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel1, _ := topic.GetOrCreateChannel("ch") + + msg := NewMessage(topic.GenerateID(), []byte("test")) + topic.PutMessage(msg) + + msg2 := NewMessage(topic.GenerateID(), []byte("test2")) + topic.PutMessage(msg2) + + outputMsg := <-channel1.memoryMsgChan + channel1.StartInFlightTimeout(outputMsg, 0, opts.MsgTimeout) + channel1.FinishMessage(0, outputMsg.ID) + test.Equal(t, msg.ID, outputMsg.ID) + test.Equal(t, msg.Body, outputMsg.Body) + + // one message should be left. drain the channel + channel1.StartDraining() + + test.Equal(t, int64(1), channel1.Depth()) + outputMsg = <-channel1.memoryMsgChan + channel1.StartInFlightTimeout(outputMsg, 0, opts.MsgTimeout) + channel1.FinishMessage(0, outputMsg.ID) + test.Equal(t, msg2.ID, outputMsg.ID) + test.Equal(t, msg2.Body, outputMsg.Body) + + time.Sleep(time.Millisecond) + c, _ := topic.GetExistingChannel("ch") + test.Nil(t, c) +} diff --git a/nsqd/http.go b/nsqd/http.go index f0db91a2c..ae9572a34 100644 --- a/nsqd/http.go +++ b/nsqd/http.go @@ -67,6 +67,7 @@ func newHTTPServer(nsqd *NSQD, tlsEnabled bool, tlsRequired bool) *httpServer { router.Handle("POST", "/topic/empty", http_api.Decorate(s.doEmptyTopic, log, http_api.V1)) router.Handle("POST", "/topic/pause", http_api.Decorate(s.doPauseTopic, log, http_api.V1)) router.Handle("POST", "/topic/unpause", http_api.Decorate(s.doPauseTopic, log, http_api.V1)) + router.Handle("POST", "/topic/drain", http_api.Decorate(s.doDrainTopic, log, http_api.V1)) router.Handle("POST", "/channel/create", http_api.Decorate(s.doCreateChannel, log, http_api.V1)) router.Handle("POST", "/channel/delete", http_api.Decorate(s.doDeleteChannel, log, http_api.V1)) router.Handle("POST", "/channel/empty", http_api.Decorate(s.doEmptyChannel, log, http_api.V1)) @@ -74,6 +75,8 @@ func newHTTPServer(nsqd *NSQD, tlsEnabled bool, tlsRequired bool) *httpServer { router.Handle("POST", "/channel/unpause", http_api.Decorate(s.doPauseChannel, log, http_api.V1)) router.Handle("GET", "/config/:opt", http_api.Decorate(s.doConfig, log, http_api.V1)) router.Handle("PUT", "/config/:opt", http_api.Decorate(s.doConfig, log, http_api.V1)) + router.Handle("PUT", "/state/drain", http_api.Decorate(s.startDraining, log, http_api.V1)) + router.Handle("PUT", "/state/shutdown", http_api.Decorate(s.shutdown, log, http_api.V1)) // debug router.HandlerFunc("GET", "/debug/pprof/", pprof.Index) @@ -142,7 +145,27 @@ func (s *httpServer) doInfo(w http.ResponseWriter, req *http.Request, ps httprou }, nil } -func (s *httpServer) getExistingTopicFromQuery(req *http.Request) (*http_api.ReqParams, *Topic, string, error) { +func (s *httpServer) getExistingTopicFromQuery(req *http.Request) (*http_api.ReqParams, *Topic, error) { + reqParams, err := http_api.NewReqParams(req) + if err != nil { + s.nsqd.logf(LOG_ERROR, "failed to parse request params - %s", err) + return nil, nil, http_api.Err{400, "INVALID_REQUEST"} + } + + topicName, err := http_api.GetTopicArg(reqParams) + if err != nil { + return nil, nil, http_api.Err{400, err.Error()} + } + + topic, err := s.nsqd.GetExistingTopic(topicName) + if err != nil { + return nil, nil, http_api.Err{404, "TOPIC_NOT_FOUND"} + } + + return reqParams, topic, err +} + +func (s *httpServer) getExistingTopicChannelFromQuery(req *http.Request) (*http_api.ReqParams, *Topic, string, error) { reqParams, err := http_api.NewReqParams(req) if err != nil { s.nsqd.logf(LOG_ERROR, "failed to parse request params - %s", err) @@ -178,8 +201,12 @@ func (s *httpServer) getTopicFromQuery(req *http.Request) (url.Values, *Topic, e if !protocol.IsValidTopicName(topicName) { return nil, nil, http_api.Err{400, "INVALID_TOPIC"} } + topic, err := s.nsqd.GetOrCreateTopic(topicName) + if err != nil { + return nil, nil, http_api.Err{503, "EXITING"} + } - return reqParams, s.nsqd.GetTopic(topicName), nil + return reqParams, topic, nil } func (s *httpServer) doPUB(w http.ResponseWriter, req *http.Request, ps httprouter.Params) (interface{}, error) { @@ -316,26 +343,10 @@ func (s *httpServer) doCreateTopic(w http.ResponseWriter, req *http.Request, ps } func (s *httpServer) doEmptyTopic(w http.ResponseWriter, req *http.Request, ps httprouter.Params) (interface{}, error) { - reqParams, err := http_api.NewReqParams(req) - if err != nil { - s.nsqd.logf(LOG_ERROR, "failed to parse request params - %s", err) - return nil, http_api.Err{400, "INVALID_REQUEST"} - } - - topicName, err := reqParams.Get("topic") + _, topic, err := s.getExistingTopicFromQuery(req) if err != nil { - return nil, http_api.Err{400, "MISSING_ARG_TOPIC"} - } - - if !protocol.IsValidTopicName(topicName) { - return nil, http_api.Err{400, "INVALID_TOPIC"} - } - - topic, err := s.nsqd.GetExistingTopic(topicName) - if err != nil { - return nil, http_api.Err{404, "TOPIC_NOT_FOUND"} + return nil, err } - err = topic.Empty() if err != nil { return nil, http_api.Err{500, "INTERNAL_ERROR"} @@ -350,10 +361,9 @@ func (s *httpServer) doDeleteTopic(w http.ResponseWriter, req *http.Request, ps s.nsqd.logf(LOG_ERROR, "failed to parse request params - %s", err) return nil, http_api.Err{400, "INVALID_REQUEST"} } - - topicName, err := reqParams.Get("topic") + topicName, err := http_api.GetTopicArg(reqParams) if err != nil { - return nil, http_api.Err{400, "MISSING_ARG_TOPIC"} + return nil, http_api.Err{400, err.Error()} } err = s.nsqd.DeleteExistingTopic(topicName) @@ -365,20 +375,9 @@ func (s *httpServer) doDeleteTopic(w http.ResponseWriter, req *http.Request, ps } func (s *httpServer) doPauseTopic(w http.ResponseWriter, req *http.Request, ps httprouter.Params) (interface{}, error) { - reqParams, err := http_api.NewReqParams(req) - if err != nil { - s.nsqd.logf(LOG_ERROR, "failed to parse request params - %s", err) - return nil, http_api.Err{400, "INVALID_REQUEST"} - } - - topicName, err := reqParams.Get("topic") - if err != nil { - return nil, http_api.Err{400, "MISSING_ARG_TOPIC"} - } - - topic, err := s.nsqd.GetExistingTopic(topicName) + _, topic, err := s.getExistingTopicFromQuery(req) if err != nil { - return nil, http_api.Err{404, "TOPIC_NOT_FOUND"} + return nil, err } if strings.Contains(req.URL.Path, "unpause") { @@ -399,17 +398,33 @@ func (s *httpServer) doPauseTopic(w http.ResponseWriter, req *http.Request, ps h return nil, nil } +// doDrainTopic initiates draining of a single topic. +// +// This is a noop if the topic is already draining or exiting +func (s *httpServer) doDrainTopic(w http.ResponseWriter, req *http.Request, ps httprouter.Params) (interface{}, error) { + _, topic, err := s.getExistingTopicFromQuery(req) + if err != nil { + return nil, err + } + + topic.StartDraining() + return nil, nil +} + func (s *httpServer) doCreateChannel(w http.ResponseWriter, req *http.Request, ps httprouter.Params) (interface{}, error) { - _, topic, channelName, err := s.getExistingTopicFromQuery(req) + _, topic, channelName, err := s.getExistingTopicChannelFromQuery(req) if err != nil { return nil, err } - topic.GetChannel(channelName) + _, err = topic.GetOrCreateChannel(channelName) + if err != nil { + return nil, http_api.Err{503, "EXITING"} + } return nil, nil } func (s *httpServer) doEmptyChannel(w http.ResponseWriter, req *http.Request, ps httprouter.Params) (interface{}, error) { - _, topic, channelName, err := s.getExistingTopicFromQuery(req) + _, topic, channelName, err := s.getExistingTopicChannelFromQuery(req) if err != nil { return nil, err } @@ -428,7 +443,7 @@ func (s *httpServer) doEmptyChannel(w http.ResponseWriter, req *http.Request, ps } func (s *httpServer) doDeleteChannel(w http.ResponseWriter, req *http.Request, ps httprouter.Params) (interface{}, error) { - _, topic, channelName, err := s.getExistingTopicFromQuery(req) + _, topic, channelName, err := s.getExistingTopicChannelFromQuery(req) if err != nil { return nil, err } @@ -442,7 +457,7 @@ func (s *httpServer) doDeleteChannel(w http.ResponseWriter, req *http.Request, p } func (s *httpServer) doPauseChannel(w http.ResponseWriter, req *http.Request, ps httprouter.Params) (interface{}, error) { - _, topic, channelName, err := s.getExistingTopicFromQuery(req) + _, topic, channelName, err := s.getExistingTopicChannelFromQuery(req) if err != nil { return nil, err } @@ -661,3 +676,22 @@ func getOptByCfgName(opts interface{}, name string) (interface{}, bool) { } return nil, false } + +func (s *httpServer) startDraining(w http.ResponseWriter, req *http.Request, ps httprouter.Params) (interface{}, error) { + go func() { + // in some cases StartDraining results in an exit immediately + // allow this API call to respond before exiting + time.Sleep(time.Millisecond) + s.nsqd.StartDraining() + }() + return nil, nil +} + +func (s *httpServer) shutdown(w http.ResponseWriter, req *http.Request, ps httprouter.Params) (interface{}, error) { + go func() { + // allow this API call to respond before exiting + time.Sleep(time.Millisecond) + s.nsqd.Exit() + }() + return nil, nil +} diff --git a/nsqd/http_test.go b/nsqd/http_test.go index c6cb3df14..94ac5e580 100644 --- a/nsqd/http_test.go +++ b/nsqd/http_test.go @@ -12,6 +12,7 @@ import ( "runtime" "strconv" "sync" + "sync/atomic" "testing" "time" @@ -45,7 +46,7 @@ func TestHTTPpub(t *testing.T) { defer nsqd.Exit() topicName := "test_http_pub" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) buf := bytes.NewBuffer([]byte("test message")) url := fmt.Sprintf("http://%s/pub?topic=%s", httpAddr, topicName) @@ -68,7 +69,7 @@ func TestHTTPpubEmpty(t *testing.T) { defer nsqd.Exit() topicName := "test_http_pub_empty" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) buf := bytes.NewBuffer([]byte("")) url := fmt.Sprintf("http://%s/pub?topic=%s", httpAddr, topicName) @@ -92,7 +93,7 @@ func TestHTTPmpub(t *testing.T) { defer nsqd.Exit() topicName := "test_http_mpub" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := []byte("test message") msgs := make([][]byte, 4) @@ -121,7 +122,7 @@ func TestHTTPmpubEmpty(t *testing.T) { defer nsqd.Exit() topicName := "test_http_mpub_empty" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := []byte("test message") msgs := make([][]byte, 4) @@ -152,7 +153,7 @@ func TestHTTPmpubBinary(t *testing.T) { defer nsqd.Exit() topicName := "test_http_mpub_bin" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) mpub := make([][]byte, 5) for i := range mpub { @@ -181,7 +182,7 @@ func TestHTTPmpubForNonNormalizedBinaryParam(t *testing.T) { defer nsqd.Exit() topicName := "test_http_mpub_bin" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) mpub := make([][]byte, 5) for i := range mpub { @@ -210,8 +211,8 @@ func TestHTTPpubDefer(t *testing.T) { defer nsqd.Exit() topicName := "test_http_pub_defer" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) - ch := topic.GetChannel("ch") + topic, _ := nsqd.GetOrCreateTopic(topicName) + ch, _ := topic.GetOrCreateChannel("ch") buf := bytes.NewBuffer([]byte("test message")) url := fmt.Sprintf("http://%s/pub?topic=%s&defer=%d", httpAddr, topicName, 1000) @@ -241,7 +242,7 @@ func TestHTTPSRequire(t *testing.T) { defer nsqd.Exit() topicName := "test_http_pub_req" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) buf := bytes.NewBuffer([]byte("test message")) url := fmt.Sprintf("http://%s/pub?topic=%s", httpAddr, topicName) @@ -288,7 +289,7 @@ func TestHTTPSRequireVerify(t *testing.T) { httpsAddr := nsqd.httpsListener.Addr().(*net.TCPAddr) topicName := "test_http_pub_req_verf" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) // no cert buf := bytes.NewBuffer([]byte("test message")) @@ -352,7 +353,7 @@ func TestTLSRequireVerifyExceptHTTP(t *testing.T) { defer nsqd.Exit() topicName := "test_http_req_verf_except_http" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) // no cert buf := bytes.NewBuffer([]byte("test message")) @@ -404,33 +405,15 @@ func TestHTTPV1TopicChannel(t *testing.T) { test.Nil(t, err) test.NotNil(t, channel) - em := ErrMessage{} - url = fmt.Sprintf("http://%s/topic/pause", httpAddr) resp, err = http.Post(url, "application/json", nil) test.Nil(t, err) - test.Equal(t, 400, resp.StatusCode) - test.Equal(t, "Bad Request", http.StatusText(resp.StatusCode)) - body, _ = ioutil.ReadAll(resp.Body) - resp.Body.Close() - - t.Logf("%s", body) - err = json.Unmarshal(body, &em) - test.Nil(t, err) - test.Equal(t, "MISSING_ARG_TOPIC", em.Message) + test.HTTPError(t, resp, 400, "MISSING_ARG_TOPIC") url = fmt.Sprintf("http://%s/topic/pause?topic=%s", httpAddr, topicName+"abc") resp, err = http.Post(url, "application/json", nil) test.Nil(t, err) - test.Equal(t, 404, resp.StatusCode) - test.Equal(t, "Not Found", http.StatusText(resp.StatusCode)) - body, _ = ioutil.ReadAll(resp.Body) - resp.Body.Close() - - t.Logf("%s", body) - err = json.Unmarshal(body, &em) - test.Nil(t, err) - test.Equal(t, "TOPIC_NOT_FOUND", em.Message) + test.HTTPError(t, resp, 404, "TOPIC_NOT_FOUND") url = fmt.Sprintf("http://%s/topic/pause?topic=%s", httpAddr, topicName) resp, err = http.Post(url, "application/json", nil) @@ -691,42 +674,24 @@ func TestDeleteTopic(t *testing.T) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - em := ErrMessage{} - url := fmt.Sprintf("http://%s/topic/delete", httpAddr) resp, err := http.Post(url, "application/json", nil) test.Nil(t, err) - test.Equal(t, 400, resp.StatusCode) - test.Equal(t, "Bad Request", http.StatusText(resp.StatusCode)) - body, _ := ioutil.ReadAll(resp.Body) - resp.Body.Close() - - t.Logf("%s", body) - err = json.Unmarshal(body, &em) - test.Nil(t, err) - test.Equal(t, "MISSING_ARG_TOPIC", em.Message) + test.HTTPError(t, resp, 400, "MISSING_ARG_TOPIC") topicName := "test_http_delete_topic" + strconv.Itoa(int(time.Now().Unix())) url = fmt.Sprintf("http://%s/topic/delete?topic=%s", httpAddr, topicName) resp, err = http.Post(url, "application/json", nil) test.Nil(t, err) - test.Equal(t, 404, resp.StatusCode) - test.Equal(t, "Not Found", http.StatusText(resp.StatusCode)) - body, _ = ioutil.ReadAll(resp.Body) - resp.Body.Close() - - t.Logf("%s", body) - err = json.Unmarshal(body, &em) - test.Nil(t, err) - test.Equal(t, "TOPIC_NOT_FOUND", em.Message) + test.HTTPError(t, resp, 404, "TOPIC_NOT_FOUND") - nsqd.GetTopic(topicName) + nsqd.GetOrCreateTopic(topicName) resp, err = http.Post(url, "application/json", nil) test.Nil(t, err) test.Equal(t, 200, resp.StatusCode) - body, _ = ioutil.ReadAll(resp.Body) + body, _ := ioutil.ReadAll(resp.Body) resp.Body.Close() t.Logf("%s", body) @@ -740,55 +705,29 @@ func TestEmptyTopic(t *testing.T) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - em := ErrMessage{} - url := fmt.Sprintf("http://%s/topic/empty", httpAddr) resp, err := http.Post(url, "application/json", nil) test.Nil(t, err) - test.Equal(t, 400, resp.StatusCode) - test.Equal(t, "Bad Request", http.StatusText(resp.StatusCode)) - body, _ := ioutil.ReadAll(resp.Body) - resp.Body.Close() - - t.Logf("%s", body) - err = json.Unmarshal(body, &em) - test.Nil(t, err) - test.Equal(t, "MISSING_ARG_TOPIC", em.Message) + test.HTTPError(t, resp, 400, "MISSING_ARG_TOPIC") topicName := "test_http_empty_topic" + strconv.Itoa(int(time.Now().Unix())) url = fmt.Sprintf("http://%s/topic/empty?topic=%s", httpAddr, topicName+"$") resp, err = http.Post(url, "application/json", nil) test.Nil(t, err) - test.Equal(t, 400, resp.StatusCode) - test.Equal(t, "Bad Request", http.StatusText(resp.StatusCode)) - body, _ = ioutil.ReadAll(resp.Body) - resp.Body.Close() - - t.Logf("%s", body) - err = json.Unmarshal(body, &em) - test.Nil(t, err) - test.Equal(t, "INVALID_TOPIC", em.Message) + test.HTTPError(t, resp, 400, "INVALID_ARG_TOPIC") url = fmt.Sprintf("http://%s/topic/empty?topic=%s", httpAddr, topicName) resp, err = http.Post(url, "application/json", nil) test.Nil(t, err) - test.Equal(t, 404, resp.StatusCode) - test.Equal(t, "Not Found", http.StatusText(resp.StatusCode)) - body, _ = ioutil.ReadAll(resp.Body) - resp.Body.Close() - - t.Logf("%s", body) - err = json.Unmarshal(body, &em) - test.Nil(t, err) - test.Equal(t, "TOPIC_NOT_FOUND", em.Message) + test.HTTPError(t, resp, 404, "TOPIC_NOT_FOUND") - nsqd.GetTopic(topicName) + nsqd.GetOrCreateTopic(topicName) resp, err = http.Post(url, "application/json", nil) test.Nil(t, err) test.Equal(t, 200, resp.StatusCode) - body, _ = ioutil.ReadAll(resp.Body) + body, _ := ioutil.ReadAll(resp.Body) resp.Body.Close() t.Logf("%s", body) @@ -802,35 +741,17 @@ func TestEmptyChannel(t *testing.T) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - em := ErrMessage{} - url := fmt.Sprintf("http://%s/channel/empty", httpAddr) resp, err := http.Post(url, "application/json", nil) test.Nil(t, err) - test.Equal(t, 400, resp.StatusCode) - test.Equal(t, "Bad Request", http.StatusText(resp.StatusCode)) - body, _ := ioutil.ReadAll(resp.Body) - resp.Body.Close() - - t.Logf("%s", body) - err = json.Unmarshal(body, &em) - test.Nil(t, err) - test.Equal(t, "MISSING_ARG_TOPIC", em.Message) + test.HTTPError(t, resp, 400, "MISSING_ARG_TOPIC") topicName := "test_http_empty_channel" + strconv.Itoa(int(time.Now().Unix())) url = fmt.Sprintf("http://%s/channel/empty?topic=%s", httpAddr, topicName) resp, err = http.Post(url, "application/json", nil) test.Nil(t, err) - test.Equal(t, 400, resp.StatusCode) - test.Equal(t, "Bad Request", http.StatusText(resp.StatusCode)) - body, _ = ioutil.ReadAll(resp.Body) - resp.Body.Close() - - t.Logf("%s", body) - err = json.Unmarshal(body, &em) - test.Nil(t, err) - test.Equal(t, "MISSING_ARG_CHANNEL", em.Message) + test.HTTPError(t, resp, 400, "MISSING_ARG_CHANNEL") channelName := "ch" @@ -838,40 +759,64 @@ func TestEmptyChannel(t *testing.T) { resp, err = http.Post(url, "application/json", nil) test.Nil(t, err) test.Equal(t, 404, resp.StatusCode) - test.Equal(t, "Not Found", http.StatusText(resp.StatusCode)) - body, _ = ioutil.ReadAll(resp.Body) - resp.Body.Close() + test.HTTPError(t, resp, 404, "TOPIC_NOT_FOUND") - t.Logf("%s", body) - err = json.Unmarshal(body, &em) + topic, _ := nsqd.GetOrCreateTopic(topicName) + + url = fmt.Sprintf("http://%s/channel/empty?topic=%s&channel=%s", httpAddr, topicName, channelName) + resp, err = http.Post(url, "application/json", nil) test.Nil(t, err) - test.Equal(t, "TOPIC_NOT_FOUND", em.Message) + test.HTTPError(t, resp, 404, "CHANNEL_NOT_FOUND") - topic := nsqd.GetTopic(topicName) + topic.GetOrCreateChannel(channelName) - url = fmt.Sprintf("http://%s/channel/empty?topic=%s&channel=%s", httpAddr, topicName, channelName) resp, err = http.Post(url, "application/json", nil) test.Nil(t, err) - test.Equal(t, 404, resp.StatusCode) - test.Equal(t, "Not Found", http.StatusText(resp.StatusCode)) - body, _ = ioutil.ReadAll(resp.Body) + test.Equal(t, 200, resp.StatusCode) + body, _ := ioutil.ReadAll(resp.Body) resp.Body.Close() t.Logf("%s", body) - err = json.Unmarshal(body, &em) + test.Equal(t, []byte(""), body) +} + +func TestTopicDrain(t *testing.T) { + opts := NewOptions() + opts.Logger = test.NewTestLogger(t) + _, httpAddr, nsqd := mustStartNSQD(opts) + defer os.RemoveAll(opts.DataPath) + defer nsqd.Exit() + + url := fmt.Sprintf("http://%s/topic/drain", httpAddr) + resp, err := http.Post(url, "application/json", nil) test.Nil(t, err) - test.Equal(t, "CHANNEL_NOT_FOUND", em.Message) + test.HTTPError(t, resp, 400, "MISSING_ARG_TOPIC") - topic.GetChannel(channelName) + topicName := "test_http_topic_drain" + strconv.Itoa(int(time.Now().Unix())) + + url = fmt.Sprintf("http://%s/topic/drain?topic=%s", httpAddr, topicName+"$") + resp, err = http.Post(url, "application/json", nil) + test.HTTPError(t, resp, 400, "INVALID_ARG_TOPIC") + + url = fmt.Sprintf("http://%s/topic/drain?topic=%s", httpAddr, topicName) + resp, err = http.Post(url, "application/json", nil) + test.Nil(t, err) + test.HTTPError(t, resp, 404, "TOPIC_NOT_FOUND") + + nsqd.GetOrCreateTopic(topicName) resp, err = http.Post(url, "application/json", nil) test.Nil(t, err) test.Equal(t, 200, resp.StatusCode) - body, _ = ioutil.ReadAll(resp.Body) + body, _ := ioutil.ReadAll(resp.Body) resp.Body.Close() t.Logf("%s", body) test.Equal(t, []byte(""), body) + + time.Sleep(time.Millisecond * 100) + topic, _ := nsqd.GetExistingTopic(topicName) + test.Nil(t, topic) } func TestInfo(t *testing.T) { @@ -896,6 +841,40 @@ func TestInfo(t *testing.T) { test.Equal(t, version.Binary, info.Version) } +func TestShutdown(t *testing.T) { + opts := NewOptions() + opts.Logger = test.NewTestLogger(t) + _, httpAddr, nsqd := mustStartNSQD(opts) + defer os.RemoveAll(opts.DataPath) + defer nsqd.Exit() + + url := fmt.Sprintf("http://%s/state/shutdown", httpAddr) + req, err := http.NewRequest(http.MethodPut, url, nil) + test.Nil(t, err) + resp, err := http.DefaultClient.Do(req) + test.Nil(t, err) + test.Equal(t, 200, resp.StatusCode) + time.Sleep(10 * time.Millisecond) + test.Equal(t, int32(1), atomic.LoadInt32(&nsqd.isExiting)) +} + +func TestDrain(t *testing.T) { + opts := NewOptions() + opts.Logger = test.NewTestLogger(t) + _, httpAddr, nsqd := mustStartNSQD(opts) + defer os.RemoveAll(opts.DataPath) + defer nsqd.Exit() + + url := fmt.Sprintf("http://%s/state/drain", httpAddr) + req, err := http.NewRequest(http.MethodPut, url, nil) + test.Nil(t, err) + resp, err := http.DefaultClient.Do(req) + test.Nil(t, err) + test.Equal(t, 200, resp.StatusCode) + time.Sleep(10 * time.Millisecond) + test.Equal(t, int32(1), atomic.LoadInt32(&nsqd.isExiting)) +} + func BenchmarkHTTPpub(b *testing.B) { var wg sync.WaitGroup b.StopTimer() diff --git a/nsqd/nsqd.go b/nsqd/nsqd.go index 550647d6e..b7c047aad 100644 --- a/nsqd/nsqd.go +++ b/nsqd/nsqd.go @@ -48,10 +48,12 @@ type NSQD struct { opts atomic.Value - dl *dirlock.DirLock - isLoading int32 - errValue atomic.Value - startTime time.Time + dl *dirlock.DirLock + isLoading int32 + isDraining int32 + isExiting int32 + errValue atomic.Value + startTime time.Time topicMap map[string]*Topic @@ -327,7 +329,11 @@ func (n *NSQD) LoadMetadata() error { n.logf(LOG_WARN, "skipping creation of invalid topic %s", t.Name) continue } - topic := n.GetTopic(t.Name) + topic, err := n.GetOrCreateTopic(t.Name) + if err != nil { + n.logf(LOG_WARN, "skipping creation of topic, nsqd draining %s", t.Name) + continue + } if t.Paused { topic.Pause() } @@ -336,8 +342,8 @@ func (n *NSQD) LoadMetadata() error { n.logf(LOG_WARN, "skipping creation of invalid channel %s", c.Name) continue } - channel := topic.GetChannel(c.Name) - if c.Paused { + channel, err := topic.GetOrCreateChannel(c.Name) + if c.Paused && err != nil { channel.Pause() } } @@ -402,13 +408,21 @@ func (n *NSQD) PersistMetadata() error { return nil } -// TermSignal handles a SIGTERM calling Exit +// TermSignal handles a SIGTERM calling either StartDraining or Exit depending on settings. // This is a noop after first call func (n *NSQD) TermSignal() { - n.Exit() + if n.getOpts().SigtermMode == "drain" { + n.StartDraining() + } else { + n.Exit() + } } func (n *NSQD) Exit() { + if !atomic.CompareAndSwapInt32(&n.isExiting, 0, 1) { + // avoid double call + return + } if n.tcpListener != nil { n.tcpListener.Close() } @@ -444,15 +458,16 @@ func (n *NSQD) Exit() { n.ctxCancel() } -// GetTopic performs a thread safe operation -// to return a pointer to a Topic object (potentially new) -func (n *NSQD) GetTopic(topicName string) *Topic { +// GetOrCreateTopic performs a thread safe operation to get an existing topic or create a new one +// +// An error will be returned if nsqd is draining +func (n *NSQD) GetOrCreateTopic(topicName string) (*Topic, error) { // most likely, we already have this topic, so try read lock first. n.RLock() t, ok := n.topicMap[topicName] n.RUnlock() if ok { - return t + return t, nil } n.Lock() @@ -460,10 +475,26 @@ func (n *NSQD) GetTopic(topicName string) *Topic { t, ok = n.topicMap[topicName] if ok { n.Unlock() - return t + return t, nil } + if atomic.LoadInt32(&n.isDraining) == 1 { + // don't create new topics when nsqd is draining + return nil, errors.New("nsqd draining") + } + deleteCallback := func(t *Topic) { n.DeleteExistingTopic(t.name) + + // if nsqd is draining, check if this is removing the last topic + // and exit nsqd if it is + if atomic.LoadInt32(&t.isDraining) == 1 { + n.RLock() + topicCount := len(n.topicMap) + n.RUnlock() + if topicCount == 0 { + n.Exit() + } + } } t = NewTopic(topicName, n, deleteCallback) n.topicMap[topicName] = t @@ -475,7 +506,7 @@ func (n *NSQD) GetTopic(topicName string) *Topic { // if loading metadata at startup, no lookupd connections yet, topic started after load if atomic.LoadInt32(&n.isLoading) == 1 { - return t + return t, nil } // if using lookupd, make a blocking call to get the topics, and immediately create them. @@ -490,7 +521,7 @@ func (n *NSQD) GetTopic(topicName string) *Topic { if strings.HasSuffix(channelName, "#ephemeral") { continue // do not create ephemeral channel with no consumer client } - t.GetChannel(channelName) + t.GetOrCreateChannel(channelName) } } else if len(n.getOpts().NSQLookupdTCPAddresses) > 0 { n.logf(LOG_ERROR, "no available nsqlookupd to query for channels to pre-create for topic %s", t.name) @@ -498,7 +529,7 @@ func (n *NSQD) GetTopic(topicName string) *Topic { // now that all channels are added, start topic messagePump t.Start() - return t + return t, nil } // GetExistingTopic gets a topic only if it exists @@ -537,6 +568,33 @@ func (n *NSQD) DeleteExistingTopic(topicName string) error { return nil } +// StartDraining starts the process of draining all topics. If there are none +// Exit will be called immediately. +func (n *NSQD) StartDraining() { + if atomic.LoadInt32(&n.isLoading) == 1 { + return + } + if atomic.LoadInt32(&n.isExiting) == 1 { + return + } + if !atomic.CompareAndSwapInt32(&n.isDraining, 0, 1) { + return + } + + n.logf(LOG_INFO, "NSQ: draining") + + n.RLock() + for _, t := range n.topicMap { + t.StartDraining() + } + numberTopics := len(n.topicMap) + n.RUnlock() + if numberTopics == 0 { + n.Exit() + } + return +} + func (n *NSQD) Notify(v interface{}) { // since the in-memory metadata is incomplete, // should not persist metadata while loading it. diff --git a/nsqd/nsqd_test.go b/nsqd/nsqd_test.go index b39df44e8..dc81c9ab8 100644 --- a/nsqd/nsqd_test.go +++ b/nsqd/nsqd_test.go @@ -65,7 +65,7 @@ func TestStartup(t *testing.T) { err := nsqd.PersistMetadata() test.Nil(t, err) atomic.StoreInt32(&nsqd.isLoading, 1) - nsqd.GetTopic(topicName) // will not persist if `flagLoading` + nsqd.GetOrCreateTopic(topicName) // will not persist if `flagLoading` m, err := getMetadata(nsqd) test.Nil(t, err) test.Equal(t, 0, len(m.Topics)) @@ -73,14 +73,14 @@ func TestStartup(t *testing.T) { atomic.StoreInt32(&nsqd.isLoading, 0) body := make([]byte, 256) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) for i := 0; i < iterations; i++ { msg := NewMessage(topic.GenerateID(), body) topic.PutMessage(msg) } t.Logf("pulling from channel") - channel1 := topic.GetChannel("ch1") + channel1, _ := topic.GetOrCreateChannel("ch1") t.Logf("read %d msgs", iterations/2) for i := 0; i < iterations/2; i++ { @@ -124,12 +124,12 @@ func TestStartup(t *testing.T) { doneExitChan <- 1 }() - topic = nsqd.GetTopic(topicName) + topic, _ = nsqd.GetOrCreateTopic(topicName) // should be empty; channel should have drained everything count := topic.Depth() test.Equal(t, int64(0), count) - channel1 = topic.GetChannel("ch1") + channel1, _ = topic.GetOrCreateChannel("ch1") for { if channel1.Depth() == int64(iterations/2) { @@ -176,8 +176,8 @@ func TestEphemeralTopicsAndChannels(t *testing.T) { }() body := []byte("an_ephemeral_message") - topic := nsqd.GetTopic(topicName) - ephemeralChannel := topic.GetChannel("ch1#ephemeral") + topic, _ := nsqd.GetOrCreateTopic(topicName) + ephemeralChannel, _ := topic.GetOrCreateChannel("ch1#ephemeral") client := newClientV2(0, nil, nsqd) err := ephemeralChannel.AddClient(client.ID, client) test.Equal(t, err, nil) @@ -215,8 +215,8 @@ func TestPauseMetadata(t *testing.T) { // avoid concurrency issue of async PersistMetadata() calls atomic.StoreInt32(&nsqd.isLoading, 1) topicName := "pause_metadata" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) - channel := topic.GetChannel("ch") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel("ch") atomic.StoreInt32(&nsqd.isLoading, 0) nsqd.PersistMetadata() diff --git a/nsqd/options.go b/nsqd/options.go index 20cf3ed3d..5e2be5d7f 100644 --- a/nsqd/options.go +++ b/nsqd/options.go @@ -3,10 +3,12 @@ package nsqd import ( "crypto/md5" "crypto/tls" + "fmt" "hash/crc32" "io" "log" "os" + "strings" "time" "github.com/nsqio/nsq/internal/lg" @@ -14,10 +16,11 @@ import ( type Options struct { // basic options - ID int64 `flag:"node-id" cfg:"id"` - LogLevel lg.LogLevel `flag:"log-level"` - LogPrefix string `flag:"log-prefix"` - Logger Logger + ID int64 `flag:"node-id" cfg:"id"` + LogLevel lg.LogLevel `flag:"log-level"` + LogPrefix string `flag:"log-prefix"` + Logger Logger + SigtermMode string `flag:"sigterm-mode"` // shutdown, drain TCPAddress string `flag:"tcp-address"` HTTPAddress string `flag:"http-address"` @@ -96,9 +99,10 @@ func NewOptions() *Options { defaultID := int64(crc32.ChecksumIEEE(h.Sum(nil)) % 1024) return &Options{ - ID: defaultID, - LogPrefix: "[nsqd] ", - LogLevel: lg.INFO, + ID: defaultID, + LogPrefix: "[nsqd] ", + LogLevel: lg.INFO, + SigtermMode: "shutdown", TCPAddress: "0.0.0.0:4150", HTTPAddress: "0.0.0.0:4151", @@ -153,3 +157,24 @@ func NewOptions() *Options { TLSMinVersion: tls.VersionTLS10, } } + +type ValidationErrors []string + +func (v ValidationErrors) Error() string { + return fmt.Sprintf("Invalid configuration:\n %s", strings.Join([]string(v), "\n ")) +} + +func (o *Options) Validate() error { + var msgs ValidationErrors + + switch o.SigtermMode { + case "shutdown", "drain": + default: + msgs = append(msgs, fmt.Sprintf("invalid sigterm-mode=%q (valid: \"shutdown\", \"drain\")", o.SigtermMode)) + } + + if len(msgs) != 0 { + return msgs + } + return nil +} diff --git a/nsqd/options_test.go b/nsqd/options_test.go new file mode 100644 index 000000000..45bea54a2 --- /dev/null +++ b/nsqd/options_test.go @@ -0,0 +1,41 @@ +package nsqd + +import ( + "fmt" + "testing" + + "github.com/nsqio/nsq/internal/test" +) + +func TestOptionsValidate(t *testing.T) { + type testCase struct { + options func() *Options + expected []string + } + tests := []testCase{ + { + options: func() *Options { + o := NewOptions() + o.SigtermMode = "a" + return o + }, + expected: []string{`invalid sigterm-mode="a" (valid: "shutdown", "drain")`}, + }, + } + for i, tc := range tests { + tc := tc + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + o := tc.options() + err := o.Validate() + if err == nil { + test.Equal(t, len(tc.expected), 0) + } else { + v := err.(ValidationErrors) + for n, m := range v { + t.Logf("[%d] %s", n, m) + } + test.Equal(t, tc.expected, []string(v)) + } + }) + } +} diff --git a/nsqd/protocol_v2.go b/nsqd/protocol_v2.go index ccac1b0f1..1c1d673e3 100644 --- a/nsqd/protocol_v2.go +++ b/nsqd/protocol_v2.go @@ -616,8 +616,21 @@ func (p *protocolV2) SUB(client *clientV2, params [][]byte) ([]byte, error) { // Avoid adding a client to an ephemeral channel / topic which has started exiting. var channel *Channel for { - topic := p.nsqd.GetTopic(topicName) - channel = topic.GetChannel(channelName) + topic, err := p.nsqd.GetOrCreateTopic(topicName) + if err != nil { + // topic creation might be blocked because of draining + return nil, protocol.NewFatalClientErr(nil, "E_NSQD_DRAINING", + fmt.Sprintf("SUB create channel %s:%s failed. nsqd is draining", + topicName, channelName)) + } + channel, _ = topic.GetOrCreateChannel(channelName) + if err != nil { + // channel creation might be blocked because of draining + return nil, protocol.NewFatalClientErr(nil, "E_TOPIC_DRAINING", + fmt.Sprintf("SUB create channel %s:%s failed. Topic is draining with no messages left", + topicName, channelName)) + } + if err := channel.AddClient(client.ID, client); err != nil { return nil, protocol.NewFatalClientErr(nil, "E_TOO_MANY_CHANNEL_CONSUMERS", fmt.Sprintf("channel consumers for %s:%s exceeds limit of %d", @@ -801,7 +814,10 @@ func (p *protocolV2) PUB(client *clientV2, params [][]byte) ([]byte, error) { return nil, err } - topic := p.nsqd.GetTopic(topicName) + topic, err := p.nsqd.GetOrCreateTopic(topicName) + if err != nil { + return nil, protocol.NewFatalClientErr(err, "E_PUB_FAILED", "PUB failed. nsqd draining") + } msg := NewMessage(topic.GenerateID(), messageBody) err = topic.PutMessage(msg) if err != nil { @@ -830,7 +846,10 @@ func (p *protocolV2) MPUB(client *clientV2, params [][]byte) ([]byte, error) { return nil, err } - topic := p.nsqd.GetTopic(topicName) + topic, err := p.nsqd.GetOrCreateTopic(topicName) + if err != nil { + return nil, protocol.NewFatalClientErr(err, "E_PUB_FAILED", "PUB failed. nsqd draining") + } bodyLen, err := readLen(client.Reader, client.lenSlice) if err != nil { @@ -917,7 +936,10 @@ func (p *protocolV2) DPUB(client *clientV2, params [][]byte) ([]byte, error) { return nil, err } - topic := p.nsqd.GetTopic(topicName) + topic, err := p.nsqd.GetOrCreateTopic(topicName) + if err != nil { + return nil, protocol.NewFatalClientErr(err, "E_DPUB_FAILED", "DPUB failed. nsqd draining") + } msg := NewMessage(topic.GenerateID(), messageBody) msg.deferred = timeoutDuration err = topic.PutMessage(msg) diff --git a/nsqd/protocol_v2_test.go b/nsqd/protocol_v2_test.go index 84803f007..68d29537b 100644 --- a/nsqd/protocol_v2_test.go +++ b/nsqd/protocol_v2_test.go @@ -137,7 +137,7 @@ func TestBasicV2(t *testing.T) { defer nsqd.Exit() topicName := "test_v2" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := NewMessage(topic.GenerateID(), []byte("test body")) topic.PutMessage(msg) @@ -172,10 +172,10 @@ func TestMultipleConsumerV2(t *testing.T) { defer nsqd.Exit() topicName := "test_multiple_v2" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := NewMessage(topic.GenerateID(), []byte("test body")) - topic.GetChannel("ch1") - topic.GetChannel("ch2") + topic.GetOrCreateChannel("ch1") + topic.GetOrCreateChannel("ch2") topic.PutMessage(msg) for _, i := range []string{"1", "2"} { @@ -382,9 +382,9 @@ func TestPausing(t *testing.T) { _, err = nsq.Ready(1).WriteTo(conn) test.Nil(t, err) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := NewMessage(topic.GenerateID(), []byte("test body")) - channel := topic.GetChannel("ch") + channel, _ := topic.GetOrCreateChannel("ch") topic.PutMessage(msg) // receive the first message via the client, finish it, and send new RDY @@ -590,7 +590,8 @@ func TestDPUB(t *testing.T) { time.Sleep(25 * time.Millisecond) - ch := nsqd.GetTopic(topicName).GetChannel("ch") + topic, _ := nsqd.GetOrCreateTopic(topicName) + ch, _ := topic.GetOrCreateChannel("ch") ch.deferredMutex.Lock() numDef := len(ch.deferredMessages) ch.deferredMutex.Unlock() @@ -624,8 +625,8 @@ func TestTouch(t *testing.T) { identify(t, conn, nil, frameTypeResponse) sub(t, conn, topicName, "ch") - topic := nsqd.GetTopic(topicName) - channel := topic.GetChannel("ch") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel("ch") msg := NewMessage(topic.GenerateID(), []byte("test body")) topic.PutMessage(msg) @@ -667,7 +668,7 @@ func TestMaxRdyCount(t *testing.T) { test.Nil(t, err) defer conn.Close() - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := NewMessage(topic.GenerateID(), []byte("test body")) topic.PutMessage(msg) @@ -743,7 +744,7 @@ func TestOutputBuffering(t *testing.T) { outputBufferSize := 256 * 1024 outputBufferTimeout := 500 - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := NewMessage(topic.GenerateID(), make([]byte, outputBufferSize-1024)) topic.PutMessage(msg) @@ -1139,7 +1140,7 @@ func TestSnappy(t *testing.T) { _, err = nsq.Ready(1).WriteTo(rw) test.Nil(t, err) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := NewMessage(topic.GenerateID(), msgBody) topic.PutMessage(msg) @@ -1232,12 +1233,12 @@ func TestSampling(t *testing.T) { test.Equal(t, int32(sampleRate), r.SampleRate) topicName := "test_sampling" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) for i := 0; i < num; i++ { msg := NewMessage(topic.GenerateID(), []byte("test body")) topic.PutMessage(msg) } - channel := topic.GetChannel("ch") + channel, _ := topic.GetOrCreateChannel("ch") // let the topic drain into the channel time.Sleep(50 * time.Millisecond) @@ -1336,8 +1337,8 @@ func TestClientMsgTimeout(t *testing.T) { defer nsqd.Exit() topicName := "test_cmsg_timeout" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) - ch := topic.GetChannel("ch") + topic, _ := nsqd.GetOrCreateTopic(topicName) + ch, _ := topic.GetOrCreateChannel("ch") msg := NewMessage(topic.GenerateID(), make([]byte, 100)) topic.PutMessage(msg) @@ -1429,8 +1430,8 @@ func TestReqTimeoutRange(t *testing.T) { identify(t, conn, nil, frameTypeResponse) sub(t, conn, topicName, "ch") - topic := nsqd.GetTopic(topicName) - channel := topic.GetChannel("ch") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel("ch") msg := NewMessage(topic.GenerateID(), []byte("test body")) topic.PutMessage(msg) @@ -1796,12 +1797,12 @@ func benchmarkProtocolV2Sub(b *testing.B, size int) { defer os.RemoveAll(opts.DataPath) msg := make([]byte, size) topicName := "bench_v2_sub" + strconv.Itoa(b.N) + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) for i := 0; i < b.N; i++ { msg := NewMessage(topic.GenerateID(), msg) topic.PutMessage(msg) } - topic.GetChannel("ch") + topic.GetOrCreateChannel("ch") b.SetBytes(int64(len(msg))) goChan := make(chan int) rdyChan := make(chan int) @@ -1896,12 +1897,12 @@ func benchmarkProtocolV2MultiSub(b *testing.B, num int) { workers := runtime.GOMAXPROCS(0) for i := 0; i < num; i++ { topicName := "bench_v2" + strconv.Itoa(b.N) + "_" + strconv.Itoa(i) + "_" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) for i := 0; i < b.N; i++ { msg := NewMessage(topic.GenerateID(), msg) topic.PutMessage(msg) } - topic.GetChannel("ch") + topic.GetOrCreateChannel("ch") for j := 0; j < workers; j++ { wg.Add(1) diff --git a/nsqd/stats.go b/nsqd/stats.go index 94a7be7d6..e497f88f5 100644 --- a/nsqd/stats.go +++ b/nsqd/stats.go @@ -60,19 +60,12 @@ type ChannelStats struct { } func NewChannelStats(c *Channel, clients []ClientStats, clientCount int) ChannelStats { - c.inFlightMutex.Lock() - inflight := len(c.inFlightMessages) - c.inFlightMutex.Unlock() - c.deferredMutex.Lock() - deferred := len(c.deferredMessages) - c.deferredMutex.Unlock() - return ChannelStats{ ChannelName: c.name, Depth: c.Depth(), BackendDepth: c.backend.Depth(), - InFlightCount: inflight, - DeferredCount: deferred, + InFlightCount: int(c.InFlightCount()), + DeferredCount: int(c.DeferredCount()), MessageCount: atomic.LoadUint64(&c.messageCount), RequeueCount: atomic.LoadUint64(&c.requeueCount), TimeoutCount: atomic.LoadUint64(&c.timeoutCount), diff --git a/nsqd/stats_test.go b/nsqd/stats_test.go index 065dfef6f..a1663ad5e 100644 --- a/nsqd/stats_test.go +++ b/nsqd/stats_test.go @@ -22,12 +22,12 @@ func TestStats(t *testing.T) { defer nsqd.Exit() topicName := "test_stats" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) msg := NewMessage(topic.GenerateID(), []byte("test body")) topic.PutMessage(msg) accompanyTopicName := "accompany_test_stats" + strconv.Itoa(int(time.Now().Unix())) - accompanyTopic := nsqd.GetTopic(accompanyTopicName) + accompanyTopic, _ := nsqd.GetOrCreateTopic(accompanyTopicName) msg = NewMessage(accompanyTopic.GenerateID(), []byte("accompany test body")) accompanyTopic.PutMessage(msg) @@ -126,8 +126,8 @@ func TestStatsChannelLocking(t *testing.T) { defer nsqd.Exit() topicName := "test_channel_empty" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) - channel := topic.GetChannel("channel") + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel("channel") var wg sync.WaitGroup diff --git a/nsqd/topic.go b/nsqd/topic.go index 76aad14bb..76e3dcdb3 100644 --- a/nsqd/topic.go +++ b/nsqd/topic.go @@ -30,6 +30,7 @@ type Topic struct { waitGroup util.WaitGroupWrapper exitFlag int32 idFactory *guidFactory + isDraining int32 ephemeral bool deleteCallback func(*Topic) @@ -94,20 +95,51 @@ func (t *Topic) Start() { } } +// StartDraining does a clean delete of a topic. +// +// - Puts of new messages will error +// - If messages are left channel draining will start after Empty() +// or when the last message is written to a channel (in messagePump) +// - If no messages are left, channels will start draining immediately +func (t *Topic) StartDraining() { + t.Lock() // block PutMessage + defer t.Unlock() + if t.Exiting() { + return + } + if !atomic.CompareAndSwapInt32(&t.isDraining, 0, 1) { + return + } + + msgsLeft := t.Depth() + t.nsqd.logf(LOG_INFO, "TOPIC(%s): draining. depth:%d channels:%d", t.name, msgsLeft, len(t.channelMap)) + + // if no outstanding messages, start channel drain + if msgsLeft == 0 { + for _, c := range t.channelMap { + c.StartDraining() + } + if len(t.channelMap) == 0 { + go t.deleter.Do(func() { t.deleteCallback(t) }) + } + } +} + // Exiting returns a boolean indicating if this topic is closed/exiting func (t *Topic) Exiting() bool { return atomic.LoadInt32(&t.exitFlag) == 1 } -// GetChannel performs a thread safe operation -// to return a pointer to a Channel object (potentially new) -// for the given Topic -func (t *Topic) GetChannel(channelName string) *Channel { +// GetOrCreateChannel performs a thread safe operation +// to return a Channel object (potentially new) +// +// The creation might fail if the topic is draining and no messages are outstanding +func (t *Topic) GetOrCreateChannel(channelName string) (*Channel, error) { t.Lock() - channel, isNew := t.getOrCreateChannel(channelName) + channel, isNew, err := t.getOrCreateChannel(channelName) t.Unlock() - if isNew { + if isNew && err != nil { // update messagePump state select { case t.channelUpdateChan <- 1: @@ -115,22 +147,40 @@ func (t *Topic) GetChannel(channelName string) *Channel { } } - return channel + return channel, err } -// this expects the caller to handle locking -func (t *Topic) getOrCreateChannel(channelName string) (*Channel, bool) { +// getOrCreateChannel expects the caller to handle locking +func (t *Topic) getOrCreateChannel(channelName string) (c *Channel, isNew bool, err error) { channel, ok := t.channelMap[channelName] if !ok { + if atomic.LoadInt32(&t.isDraining) == 1 { + // if this topic is draining, and there are no messages on the topic don't create a new channel + if t.Depth() == 0 { + return nil, false, errors.New("topic draining") + } + } + deleteCallback := func(c *Channel) { t.DeleteExistingChannel(c.name) + if atomic.LoadInt32(&t.isDraining) == 1 { + // if no channels left; no msgs left delete + t.RLock() + numChannels := len(t.channelMap) + depth := t.Depth() + t.nsqd.logf(LOG_INFO, "TOPIC(%s): deleting channel(%s). Draining status: channels:%d topic depth:%d", t.name, c.name, numChannels, depth) + t.RUnlock() + if numChannels == 0 && depth == 0 { + go t.deleter.Do(func() { t.deleteCallback(t) }) + } + } } channel = NewChannel(t.name, channelName, t.nsqd, deleteCallback) t.channelMap[channelName] = channel t.nsqd.logf(LOG_INFO, "TOPIC(%s): new channel(%s)", t.name, channel.name) - return channel, true + return channel, true, nil } - return channel, false + return channel, false, nil } func (t *Topic) GetExistingChannel(channelName string) (*Channel, error) { @@ -182,6 +232,9 @@ func (t *Topic) PutMessage(m *Message) error { if atomic.LoadInt32(&t.exitFlag) == 1 { return errors.New("exiting") } + if atomic.LoadInt32(&t.isDraining) == 1 { + return errors.New("draining") + } err := t.put(m) if err != nil { return err @@ -198,6 +251,9 @@ func (t *Topic) PutMessages(msgs []*Message) error { if atomic.LoadInt32(&t.exitFlag) == 1 { return errors.New("exiting") } + if atomic.LoadInt32(&t.isDraining) == 1 { + return errors.New("draining") + } messageTotalBytes := 0 @@ -232,6 +288,7 @@ func (t *Topic) put(m *Message) error { return nil } +// Depth returns the number of unconsumed messages in the channel buffer or the disk backend func (t *Topic) Depth() int64 { return int64(len(t.memoryMsgChan)) + t.backend.Depth() } @@ -329,6 +386,19 @@ func (t *Topic) messagePump() { t.name, msg.ID, channel.name, err) } } + + // If in draining mode and we wrote a message to channels + // check if it was the last message on the topic (there are no more left) + // in which case we start draining each channel + if atomic.LoadInt32(&t.isDraining) == 1 { + if t.Depth() == 0 { + t.RLock() + for _, c := range t.channelMap { + c.StartDraining() + } + t.RUnlock() + } + } } exit: @@ -392,17 +462,36 @@ func (t *Topic) exit(deleted bool) error { return t.backend.Close() } +// Empty drains the topic of messages. +// +// If the topic is draining this will start draining each channel +// if there are no channels the topic will be deleted func (t *Topic) Empty() error { +MemoryDrain: for { select { case <-t.memoryMsgChan: default: - goto finish + break MemoryDrain } } -finish: - return t.backend.Empty() + err := t.backend.Empty() + if err != nil { + return err + } + + if atomic.LoadInt32(&t.isDraining) == 1 { + t.RLock() + for _, c := range t.channelMap { + c.StartDraining() + } + if len(t.channelMap) == 0 { + go t.deleter.Do(func() { t.deleteCallback(t) }) + } + t.RUnlock() + } + return nil } func (t *Topic) flush() error { diff --git a/nsqd/topic_test.go b/nsqd/topic_test.go index 0f0e1736f..42634f918 100644 --- a/nsqd/topic_test.go +++ b/nsqd/topic_test.go @@ -1,6 +1,7 @@ package nsqd import ( + "context" "errors" "fmt" "io/ioutil" @@ -8,6 +9,7 @@ import ( "os" "runtime" "strconv" + "sync" "testing" "time" @@ -21,14 +23,14 @@ func TestGetTopic(t *testing.T) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - topic1 := nsqd.GetTopic("test") + topic1, _ := nsqd.GetOrCreateTopic("test") test.NotNil(t, topic1) test.Equal(t, "test", topic1.name) - topic2 := nsqd.GetTopic("test") + topic2, _ := nsqd.GetOrCreateTopic("test") test.Equal(t, topic1, topic2) - topic3 := nsqd.GetTopic("test2") + topic3, _ := nsqd.GetOrCreateTopic("test2") test.Equal(t, "test2", topic3.name) test.NotEqual(t, topic2, topic3) } @@ -40,13 +42,13 @@ func TestGetChannel(t *testing.T) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - topic := nsqd.GetTopic("test") + topic, _ := nsqd.GetOrCreateTopic("test") - channel1 := topic.GetChannel("ch1") + channel1, _ := topic.GetOrCreateChannel("ch1") test.NotNil(t, channel1) test.Equal(t, "ch1", channel1.name) - channel2 := topic.GetChannel("ch2") + channel2, _ := topic.GetOrCreateChannel("ch2") test.Equal(t, channel1, topic.channelMap["ch1"]) test.Equal(t, channel2, topic.channelMap["ch2"]) @@ -73,7 +75,7 @@ func TestHealth(t *testing.T) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - topic := nsqd.GetTopic("test") + topic, _ := nsqd.GetOrCreateTopic("test") topic.backend = &errorBackendQueue{} msg := NewMessage(topic.GenerateID(), make([]byte, 100)) @@ -121,16 +123,16 @@ func TestDeletes(t *testing.T) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - topic := nsqd.GetTopic("test") + topic, _ := nsqd.GetOrCreateTopic("test") - channel1 := topic.GetChannel("ch1") + channel1, _ := topic.GetOrCreateChannel("ch1") test.NotNil(t, channel1) err := topic.DeleteExistingChannel("ch1") test.Nil(t, err) test.Equal(t, 0, len(topic.channelMap)) - channel2 := topic.GetChannel("ch2") + channel2, _ := topic.GetOrCreateChannel("ch2") test.NotNil(t, channel2) err = nsqd.DeleteExistingTopic("test") @@ -146,9 +148,9 @@ func TestDeleteLast(t *testing.T) { defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - topic := nsqd.GetTopic("test") + topic, _ := nsqd.GetOrCreateTopic("test") - channel1 := topic.GetChannel("ch1") + channel1, _ := topic.GetOrCreateChannel("ch1") test.NotNil(t, channel1) err := topic.DeleteExistingChannel("ch1") @@ -170,11 +172,11 @@ func TestPause(t *testing.T) { defer nsqd.Exit() topicName := "test_topic_pause" + strconv.Itoa(int(time.Now().Unix())) - topic := nsqd.GetTopic(topicName) + topic, _ := nsqd.GetOrCreateTopic(topicName) err := topic.Pause() test.Nil(t, err) - channel := topic.GetChannel("ch1") + channel, _ := topic.GetOrCreateChannel("ch1") test.NotNil(t, channel) msg := NewMessage(topic.GenerateID(), []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaa")) @@ -195,20 +197,54 @@ func TestPause(t *testing.T) { test.Equal(t, int64(1), channel.Depth()) } +func TestDrainEmpty(t *testing.T) { + opts := NewOptions() + opts.Logger = test.NewTestLogger(t) + _, _, nsqd := mustStartNSQD(opts) + defer os.RemoveAll(opts.DataPath) + defer nsqd.Exit() + + topic, _ := nsqd.GetOrCreateTopic("drain_topic_empty") + + channel1, _ := topic.GetOrCreateChannel("ch1") + test.NotNil(t, channel1) + channel2, _ := topic.GetOrCreateChannel("ch2") + test.NotNil(t, channel2) + test.Equal(t, 2, len(topic.channelMap)) + + topic.StartDraining() + time.Sleep(time.Millisecond) + + // topic exiting is slow + for i := 0; i < 10; i++ { + t, _ := nsqd.GetExistingTopic("drain_topic_empty") + if t == nil { + break + } + time.Sleep(time.Millisecond * 50) + } + test.Equal(t, 0, len(topic.channelMap)) + + topic, _ = nsqd.GetExistingTopic("drain_topic_empty") + test.Nil(t, topic) +} + func BenchmarkTopicPut(b *testing.B) { b.StopTimer() topicName := "bench_topic_put" + strconv.Itoa(b.N) opts := NewOptions() opts.Logger = test.NewTestLogger(b) + opts.LogLevel = LOG_WARN opts.MemQueueSize = int64(b.N) _, _, nsqd := mustStartNSQD(opts) defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() + gf := &test.GUIDFactory{} b.StartTimer() for i := 0; i <= b.N; i++ { - topic := nsqd.GetTopic(topicName) - msg := NewMessage(topic.GenerateID(), []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaa")) + topic, _ := nsqd.GetOrCreateTopic(topicName) + msg := NewMessage(guid(gf.NextMessageID()).Hex(), []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaa")) topic.PutMessage(msg) } } @@ -219,16 +255,18 @@ func BenchmarkTopicToChannelPut(b *testing.B) { channelName := "bench" opts := NewOptions() opts.Logger = test.NewTestLogger(b) + opts.LogLevel = LOG_WARN opts.MemQueueSize = int64(b.N) _, _, nsqd := mustStartNSQD(opts) defer os.RemoveAll(opts.DataPath) defer nsqd.Exit() - channel := nsqd.GetTopic(topicName).GetChannel(channelName) + topic, _ := nsqd.GetOrCreateTopic(topicName) + channel, _ := topic.GetOrCreateChannel(channelName) + gf := &test.GUIDFactory{} b.StartTimer() for i := 0; i <= b.N; i++ { - topic := nsqd.GetTopic(topicName) - msg := NewMessage(topic.GenerateID(), []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaa")) + msg := NewMessage(guid(gf.NextMessageID()).Hex(), []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaa")) topic.PutMessage(msg) } @@ -239,3 +277,43 @@ func BenchmarkTopicToChannelPut(b *testing.B) { runtime.Gosched() } } + +func BenchmarkTopicMessagePump(b *testing.B) { + b.StopTimer() + topicName := "bench_topic_put_throughput" + strconv.Itoa(b.N) + opts := NewOptions() + opts.Logger = test.NewTestLogger(b) + opts.LogLevel = LOG_WARN + opts.MemQueueSize = int64(b.N) + _, _, nsqd := mustStartNSQD(opts) + defer os.RemoveAll(opts.DataPath) + defer nsqd.Exit() + + topic, _ := nsqd.GetOrCreateTopic(topicName) + ch, _ := topic.GetOrCreateChannel("ch") + ctx, cancel := context.WithCancel(context.Background()) + gf := &test.GUIDFactory{} + + var wg sync.WaitGroup + for i := 0; i < runtime.GOMAXPROCS(0); i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ch.memoryMsgChan: + case <-ctx.Done(): + return + } + } + }() + } + + b.StartTimer() + for i := 0; i <= b.N; i++ { + msg := NewMessage(guid(gf.NextMessageID()).Hex(), []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaa")) + topic.PutMessage(msg) + } + cancel() + wg.Wait() +}