diff --git a/common/metrics/defs.go b/common/metrics/defs.go
index 66253576413..b9322c896c9 100644
--- a/common/metrics/defs.go
+++ b/common/metrics/defs.go
@@ -1481,6 +1481,7 @@ const (
 	ShardDistributorStoreGetStateScope
 	ShardDistributorStoreRecordHeartbeatScope
 	ShardDistributorStoreSubscribeScope
+	ShardDistributorStoreSubscribeToAssignmentChangesScope
 
 	// The scope for the shard distributor executor
 	ShardDistributorExecutorScope
@@ -2153,20 +2154,21 @@ var ScopeDefs = map[ServiceIdx]map[ScopeIdx]scopeDefinition{
 		DiagnosticsWorkflowScope: {operation: "DiagnosticsWorkflow"},
 	},
 	ShardDistributor: {
-		ShardDistributorGetShardOwnerScope:         {operation: "GetShardOwner"},
-		ShardDistributorWatchNamespaceStateScope:   {operation: "WatchNamespaceState"},
-		ShardDistributorHeartbeatScope:             {operation: "ExecutorHeartbeat"},
-		ShardDistributorAssignLoopScope:            {operation: "ShardAssignLoop"},
-		ShardDistributorExecutorScope:              {operation: "Executor"},
-		ShardDistributorStoreGetShardOwnerScope:    {operation: "StoreGetShardOwner"},
-		ShardDistributorStoreAssignShardScope:      {operation: "StoreAssignShard"},
-		ShardDistributorStoreAssignShardsScope:     {operation: "StoreAssignShards"},
-		ShardDistributorStoreDeleteExecutorsScope:  {operation: "StoreDeleteExecutors"},
-		ShardDistributorStoreDeleteShardStatsScope: {operation: "StoreDeleteShardStats"},
-		ShardDistributorStoreGetHeartbeatScope:     {operation: "StoreGetHeartbeat"},
-		ShardDistributorStoreGetStateScope:         {operation: "StoreGetState"},
-		ShardDistributorStoreRecordHeartbeatScope:  {operation: "StoreRecordHeartbeat"},
-		ShardDistributorStoreSubscribeScope:        {operation: "StoreSubscribe"},
+		ShardDistributorGetShardOwnerScope:                     {operation: "GetShardOwner"},
+		ShardDistributorWatchNamespaceStateScope:               {operation: "WatchNamespaceState"},
+		ShardDistributorHeartbeatScope:                         {operation: "ExecutorHeartbeat"},
+		ShardDistributorAssignLoopScope:                        {operation: "ShardAssignLoop"},
+		ShardDistributorExecutorScope:                          {operation: "Executor"},
+		ShardDistributorStoreGetShardOwnerScope:                {operation: "StoreGetShardOwner"},
+		ShardDistributorStoreAssignShardScope:                  {operation: "StoreAssignShard"},
+		ShardDistributorStoreAssignShardsScope:                 {operation: "StoreAssignShards"},
+		ShardDistributorStoreDeleteExecutorsScope:              {operation: "StoreDeleteExecutors"},
+		ShardDistributorStoreDeleteShardStatsScope:             {operation: "StoreDeleteShardStats"},
+		ShardDistributorStoreGetHeartbeatScope:                 {operation: "StoreGetHeartbeat"},
+		ShardDistributorStoreGetStateScope:                     {operation: "StoreGetState"},
+		ShardDistributorStoreRecordHeartbeatScope:              {operation: "StoreRecordHeartbeat"},
+		ShardDistributorStoreSubscribeScope:                    {operation: "StoreSubscribe"},
+		ShardDistributorStoreSubscribeToAssignmentChangesScope: {operation: "StoreSubscribeToAssignmentChanges"},
 	},
 }
 
diff --git a/service/sharddistributor/handler/handler.go b/service/sharddistributor/handler/handler.go
index d71a834995e..08fe6b1680c 100644
--- a/service/sharddistributor/handler/handler.go
+++ b/service/sharddistributor/handler/handler.go
@@ -142,5 +142,81 @@ func (h *handlerImpl) assignEphemeralShard(ctx context.Context, namespace string
 }
 
 func (h *handlerImpl) WatchNamespaceState(request *types.WatchNamespaceStateRequest, server WatchNamespaceStateServer) error {
-	return fmt.Errorf("not implemented")
+	h.startWG.Wait()
+
+	// Subscribe to state changes from storage
+	assignmentChangesChan, unSubscribe, err := h.storage.SubscribeToAssignmentChanges(server.Context(), request.Namespace)
+	if err != nil {
+		return fmt.Errorf("subscribe to namespace state: %w", err)
+	}
+	// Deferred only after the error check: a failed subscription may return a nil unSubscribe
+	defer unSubscribe()
+
+	// Send initial state immediately so client doesn't have to wait for first update
+	state, err := h.storage.GetState(server.Context(), request.Namespace)
+	if err != nil {
+		return fmt.Errorf("get initial state: %w", err)
+	}
+	response := toWatchNamespaceStateResponse(state)
+	if err := server.Send(response); err != nil {
+		return fmt.Errorf("send initial state: %w", err)
+	}
+
+	// Stream subsequent updates
+	for {
+		select {
+		case <-server.Context().Done():
+			return server.Context().Err()
+		case assignmentChanges, ok := <-assignmentChangesChan:
+			if !ok {
+				return fmt.Errorf("unexpected close of updates channel")
+			}
+			response := &types.WatchNamespaceStateResponse{
+				Executors: make([]*types.ExecutorShardAssignment, 0, len(assignmentChanges)),
+			}
+			for executor, shardIDs := range assignmentChanges {
+				response.Executors = append(response.Executors, &types.ExecutorShardAssignment{
+					ExecutorID:     executor.ExecutorID,
+					AssignedShards: WrapShards(shardIDs),
+					Metadata:       executor.Metadata,
+				})
+			}
+
+			err = server.Send(response)
+			if err != nil {
+				return fmt.Errorf("send response: %w", err)
+			}
+		}
+	}
+}
+
+// toWatchNamespaceStateResponse builds one ExecutorShardAssignment per entry of state.ShardAssignments.
+func toWatchNamespaceStateResponse(state *store.NamespaceState) *types.WatchNamespaceStateResponse {
+	response := &types.WatchNamespaceStateResponse{
+		Executors: make([]*types.ExecutorShardAssignment, 0, len(state.ShardAssignments)),
+	}
+
+	for executorID, assignment := range state.ShardAssignments {
+		// Extract shard IDs from the assigned shards map
+		shardIDs := make([]string, 0, len(assignment.AssignedShards))
+		for shardID := range assignment.AssignedShards {
+			shardIDs = append(shardIDs, shardID)
+		}
+
+		response.Executors = append(response.Executors, &types.ExecutorShardAssignment{
+			ExecutorID:     executorID,
+			AssignedShards: WrapShards(shardIDs),
+			Metadata:       state.Executors[executorID].Metadata,
+		})
+	}
+	return response
+}
+
+// WrapShards converts each shard ID into a types.Shard whose ShardKey is that ID.
+func WrapShards(shardIDs []string) []*types.Shard {
+	shards := make([]*types.Shard, 0, len(shardIDs))
+	for _, shardID := range shardIDs {
+		shards = append(shards, &types.Shard{ShardKey: shardID})
+	}
+	return shards
 }
diff --git a/service/sharddistributor/handler/handler_test.go b/service/sharddistributor/handler/handler_test.go
index 901c42d41b4..a683168a8be 100644
--- a/service/sharddistributor/handler/handler_test.go
+++ b/service/sharddistributor/handler/handler_test.go
@@ -25,7 +25,9 @@ package handler
 import (
 	"context"
 	"errors"
+	"sync"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/require"
 	"go.uber.org/mock/gomock"
@@ -215,3 +217,82 @@ func TestGetShardOwner(t *testing.T) {
 		})
 	}
 }
+
+func TestWatchNamespaceState(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	logger := testlogger.New(t)
+	mockStorage := store.NewMockStore(ctrl)
+	mockServer := NewMockWatchNamespaceStateServer(ctrl)
+
+	cfg := config.ShardDistribution{
+		Namespaces: []config.Namespace{
+			{Name: "test-ns", Type: config.NamespaceTypeFixed, ShardNum: 2},
+		},
+	}
+
+	handler := &handlerImpl{
+		logger:               logger,
+		shardDistributionCfg: cfg,
+		storage:              mockStorage,
+		startWG:              sync.WaitGroup{},
+	}
+
+	t.Run("successful streaming", func(t *testing.T) {
+		ctx, cancel := context.WithCancel(context.Background())
+
+		initialState := &store.NamespaceState{
+			ShardAssignments: map[string]store.AssignedState{
+				"executor-1": {
+					AssignedShards: map[string]*types.ShardAssignment{
+						"shard-1": {},
+					},
+				},
+			},
+		}
+
+		updatesChan := make(chan map[*store.ShardOwner][]string, 1)
+		unsubscribe := func() { close(updatesChan) }
+
+		mockServer.EXPECT().Context().Return(ctx).AnyTimes()
+		mockStorage.EXPECT().GetState(gomock.Any(), "test-ns").Return(initialState, nil)
+		mockStorage.EXPECT().SubscribeToAssignmentChanges(gomock.Any(), "test-ns").Return(updatesChan, unsubscribe, nil)
+
+		// Expect initial state send
+		mockServer.EXPECT().Send(gomock.Any()).DoAndReturn(func(resp *types.WatchNamespaceStateResponse) error {
+			require.Len(t, resp.Executors, 1)
+			require.Equal(t, "executor-1", resp.Executors[0].ExecutorID)
+			return nil
+		})
+
+		// Expect update send; cancel only once it was delivered so the stream loop exits deterministically
+		mockServer.EXPECT().Send(gomock.Any()).DoAndReturn(func(resp *types.WatchNamespaceStateResponse) error {
+			require.Len(t, resp.Executors, 1)
+			require.Equal(t, "executor-2", resp.Executors[0].ExecutorID)
+			cancel()
+			return nil
+		})
+
+		// Send update
+		go func() {
+			time.Sleep(10 * time.Millisecond)
+			updatesChan <- map[*store.ShardOwner][]string{
+				{ExecutorID: "executor-2", Metadata: map[string]string{}}: {"shard-2"},
+			}
+		}()
+
+		err := handler.WatchNamespaceState(&types.WatchNamespaceStateRequest{Namespace: "test-ns"}, mockServer)
+		require.Error(t, err)
+		require.ErrorIs(t, err, context.Canceled)
+	})
+
+	t.Run("storage error on initial state", func(t *testing.T) {
+		ctx := context.Background()
+		mockServer.EXPECT().Context().Return(ctx).AnyTimes()
+		mockStorage.EXPECT().GetState(gomock.Any(), "test-ns").Return(nil, errors.New("storage error"))
+		mockStorage.EXPECT().SubscribeToAssignmentChanges(gomock.Any(), "test-ns").Return(make(chan map[*store.ShardOwner][]string), func() {}, nil)
+
+		err := handler.WatchNamespaceState(&types.WatchNamespaceStateRequest{Namespace: "test-ns"}, mockServer)
+		require.Error(t, err)
+		require.Contains(t, err.Error(), "get initial state")
+	})
+}
diff --git a/service/sharddistributor/store/etcd/executorstore/etcdstore.go b/service/sharddistributor/store/etcd/executorstore/etcdstore.go
index d3498345c56..f8cdeb94af3 100644
--- a/service/sharddistributor/store/etcd/executorstore/etcdstore.go
+++ b/service/sharddistributor/store/etcd/executorstore/etcdstore.go
@@ -288,6 +288,10 @@ func (s *executorStoreImpl) GetState(ctx context.Context, namespace string) (*st
 	}, nil
 }
 
+func (s *executorStoreImpl) SubscribeToAssignmentChanges(ctx context.Context, namespace string) (<-chan map[*store.ShardOwner][]string, func(), error) {
+	return s.shardCache.Subscribe(ctx, namespace)
+}
+
 func (s *executorStoreImpl) Subscribe(ctx context.Context, namespace string) (<-chan int64, error) {
 	revisionChan := make(chan int64, 1)
 	watchPrefix := etcdkeys.BuildExecutorPrefix(s.prefix, namespace)
diff --git a/service/sharddistributor/store/etcd/executorstore/shardcache/namespaceshardcache.go b/service/sharddistributor/store/etcd/executorstore/shardcache/namespaceshardcache.go
index d54415f082b..c59b15f2d3d 100644
--- a/service/sharddistributor/store/etcd/executorstore/shardcache/namespaceshardcache.go
+++ b/service/sharddistributor/store/etcd/executorstore/shardcache/namespaceshardcache.go
@@ -19,6 +19,7 @@ type namespaceShardToExecutor struct {
 	sync.RWMutex
 
 	shardToExecutor  map[string]*store.ShardOwner
+	executorState    map[*store.ShardOwner][]string // executor -> shardIDs
 	executorRevision map[string]int64
 	namespace        string
 	etcdPrefix       string
@@ -26,6 +27,7 @@ type namespaceShardToExecutor struct {
 	stopCh           chan struct{}
 	logger           log.Logger
 	client           *clientv3.Client
+	pubSub           *executorStatePubSub
 }
 
 func newNamespaceShardToExecutor(etcdPrefix, namespace string, client *clientv3.Client, stopCh chan struct{}, logger log.Logger) (*namespaceShardToExecutor, error) {
@@ -35,6 +37,7 @@ func newNamespaceShardToExecutor(etcdPrefix, namespace string, client *clientv3.
 
 	return &namespaceShardToExecutor{
 		shardToExecutor:  make(map[string]*store.ShardOwner),
+		executorState:    make(map[*store.ShardOwner][]string),
 		executorRevision: make(map[string]int64),
 		namespace:        namespace,
 		etcdPrefix:       etcdPrefix,
@@ -42,6 +45,7 @@ func newNamespaceShardToExecutor(etcdPrefix, namespace string, client *clientv3.
 		stopCh:           stopCh,
 		logger:           logger,
 		client:           client,
+		pubSub:           newExecutorStatePubSub(logger, namespace),
 	}, nil
 }
 
@@ -94,6 +98,10 @@ func (n *namespaceShardToExecutor) GetExecutorModRevisionCmp() ([]clientv3.Cmp,
 	return comparisons, nil
 }
 
+func (n *namespaceShardToExecutor) Subscribe(ctx context.Context) (<-chan map[*store.ShardOwner][]string, func()) {
+	return n.pubSub.subscribe(ctx)
+}
+
 func (n *namespaceShardToExecutor) nameSpaceRefreashLoop() {
 	for {
 		select {
@@ -124,7 +132,24 @@ func (n *namespaceShardToExecutor) nameSpaceRefreashLoop() {
 }
 
 func (n *namespaceShardToExecutor) refresh(ctx context.Context) error {
+	err := n.refreshExecutorState(ctx)
+	if err != nil {
+		return fmt.Errorf("refresh executor state: %w", err)
+	}
+
+	n.RLock()
+	executorState := make(map[*store.ShardOwner][]string, len(n.executorState))
+	for executor, shardIDs := range n.executorState {
+		executorState[executor] = make([]string, len(shardIDs))
+		copy(executorState[executor], shardIDs)
+	}
+	n.RUnlock()
+
+	n.pubSub.publish(executorState)
+	return nil
+}
 
+func (n *namespaceShardToExecutor) refreshExecutorState(ctx context.Context) error {
 	executorPrefix := etcdkeys.BuildExecutorPrefix(n.etcdPrefix, n.namespace)
 
 	resp, err := n.client.Get(ctx, executorPrefix, clientv3.WithPrefix())
@@ -136,6 +161,7 @@ func (n *namespaceShardToExecutor) refresh(ctx context.Context) error {
 	defer n.Unlock()
 	// Clear the cache, so we don't have any stale data
 	n.shardToExecutor = make(map[string]*store.ShardOwner)
+	n.executorState = make(map[*store.ShardOwner][]string)
 	n.executorRevision = make(map[string]int64)
 
 	shardOwners := make(map[string]*store.ShardOwner)
@@ -154,10 +180,15 @@ func (n *namespaceShardToExecutor) refresh(ctx context.Context) error {
 			if err != nil {
 				return fmt.Errorf("parse assigned state: %w", err)
 			}
+
+			// Build both shard->executor and executor->shards mappings
+			shardIDs := make([]string, 0, len(assignedState.AssignedShards))
 			for shardID := range assignedState.AssignedShards {
 				n.shardToExecutor[shardID] = shardOwner
+				shardIDs = append(shardIDs, shardID)
 				n.executorRevision[executorID] = kv.ModRevision
 			}
+			n.executorState[shardOwner] = shardIDs
 
 		case etcdkeys.ExecutorMetadataKey:
 			shardOwner := getOrCreateShardOwner(shardOwners, executorID)
diff --git a/service/sharddistributor/store/etcd/executorstore/shardcache/pubsub.go b/service/sharddistributor/store/etcd/executorstore/shardcache/pubsub.go
new file mode 100644
index 00000000000..8f682fdff17
--- /dev/null
+++ b/service/sharddistributor/store/etcd/executorstore/shardcache/pubsub.go
@@ -0,0 +1,65 @@
+package shardcache
+
+import (
+	"context"
+	"sync"
+
+	"github.com/google/uuid"
+
+	"github.com/uber/cadence/common/log"
+	"github.com/uber/cadence/common/log/tag"
+	"github.com/uber/cadence/service/sharddistributor/store"
+)
+
+// executorStatePubSub manages subscriptions to executor state changes
+type executorStatePubSub struct {
+	mu          sync.RWMutex
+	subscribers map[string]chan<- map[*store.ShardOwner][]string
+	logger      log.Logger
+	namespace   string
+}
+
+func newExecutorStatePubSub(logger log.Logger, namespace string) *executorStatePubSub {
+	return &executorStatePubSub{
+		subscribers: make(map[string]chan<- map[*store.ShardOwner][]string),
+		logger:      logger,
+		namespace:   namespace,
+	}
+}
+
+// subscribe registers a subscriber and returns its update channel plus an unsubscribe func; ctx is currently unused.
+func (p *executorStatePubSub) subscribe(ctx context.Context) (<-chan map[*store.ShardOwner][]string, func()) {
+	ch := make(chan map[*store.ShardOwner][]string)
+	uniqueID := uuid.New().String()
+
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	p.subscribers[uniqueID] = ch
+
+	unSub := func() {
+		p.unSubscribe(uniqueID)
+	}
+
+	return ch, unSub
+}
+
+func (p *executorStatePubSub) unSubscribe(uniqueID string) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	delete(p.subscribers, uniqueID)
+}
+
+// publish sends the state to all subscribers (non-blocking)
+func (p *executorStatePubSub) publish(state map[*store.ShardOwner][]string) {
+	p.mu.RLock()
+	defer p.mu.RUnlock()
+
+	for _, sub := range p.subscribers {
+		select {
+		case sub <- state:
+		default:
+			// Subscriber is not reading fast enough, skip this update
+			p.logger.Warn("Subscriber not keeping up with state updates, dropping update", tag.ShardNamespace(p.namespace))
+		}
+	}
+}
diff --git a/service/sharddistributor/store/etcd/executorstore/shardcache/pubsub_test.go b/service/sharddistributor/store/etcd/executorstore/shardcache/pubsub_test.go
new file mode 100644
index 00000000000..1be9ca33ce1
--- /dev/null
+++ b/service/sharddistributor/store/etcd/executorstore/shardcache/pubsub_test.go
@@ -0,0 +1,89 @@
+package shardcache
+
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/goleak"
+
+	"github.com/uber/cadence/common/log/testlogger"
+	"github.com/uber/cadence/service/sharddistributor/store"
+)
+
+func TestExecutorStatePubSub_SubscribeUnsubscribe(t *testing.T) {
+	defer goleak.VerifyNone(t)
+	pubsub := newExecutorStatePubSub(testlogger.New(t), "test-ns")
+
+	ch, unsub := pubsub.subscribe(context.Background())
+	assert.NotNil(t, ch)
+	assert.Len(t, pubsub.subscribers, 1)
+
+	unsub()
+	assert.Len(t, pubsub.subscribers, 0)
+
+	// Unsubscribe is idempotent
+	unsub()
+	assert.Len(t, pubsub.subscribers, 0)
+}
+
+func TestExecutorStatePubSub_Publish(t *testing.T) {
+	defer goleak.VerifyNone(t)
+
+	t.Run("no subscribers doesn't panic", func(t *testing.T) {
+		pubsub := newExecutorStatePubSub(testlogger.New(t), "test-ns")
+		require.NotPanics(t, func() {
+			pubsub.publish(map[*store.ShardOwner][]string{})
+		})
+	})
+
+	t.Run("multiple subscribers receive updates", func(t *testing.T) {
+		pubsub := newExecutorStatePubSub(testlogger.New(t), "test-ns")
+		ch1, unsub1 := pubsub.subscribe(context.Background())
+		ch2, unsub2 := pubsub.subscribe(context.Background())
+		defer unsub1()
+		defer unsub2()
+
+		testState := map[*store.ShardOwner][]string{
+			{ExecutorID: "exec-1", Metadata: map[string]string{}}: {"shard-1"},
+		}
+
+		var wg sync.WaitGroup
+		wg.Add(2)
+		go func() {
+			state := <-ch1
+			assert.Equal(t, testState, state)
+			wg.Done()
+		}()
+		go func() {
+			state := <-ch2
+			assert.Equal(t, testState, state)
+			wg.Done()
+		}()
+		time.Sleep(10 * time.Millisecond)
+
+		pubsub.publish(testState)
+
+		wg.Wait()
+	})
+
+	t.Run("non-blocking publish to slow consumer", func(t *testing.T) {
+		pubsub := newExecutorStatePubSub(testlogger.New(t), "test-ns")
+
+		// We create a subscriber that doesn't read from the channel, this should still not block
+		_, slowUnsub := pubsub.subscribe(context.Background())
+		defer slowUnsub()
+
+		testState := map[*store.ShardOwner][]string{
+			{ExecutorID: "exec-1", Metadata: map[string]string{}}: {"shard-1"},
+		}
+
+		// We do not read from the slow channel, this should still not block
+		for range 10 {
+			pubsub.publish(testState)
+		}
+	})
+}
diff --git a/service/sharddistributor/store/etcd/executorstore/shardcache/shardcache.go b/service/sharddistributor/store/etcd/executorstore/shardcache/shardcache.go
index 15d94aaf237..041ec9e6302 100644
--- a/service/sharddistributor/store/etcd/executorstore/shardcache/shardcache.go
+++ b/service/sharddistributor/store/etcd/executorstore/shardcache/shardcache.go
@@ -62,6 +62,16 @@ func (s *ShardToExecutorCache) GetExecutorModRevisionCmp(namespace string) ([]cl
 	return namespaceShardToExecutor.GetExecutorModRevisionCmp()
 }
 
+func (s *ShardToExecutorCache) Subscribe(ctx context.Context, namespace string) (<-chan map[*store.ShardOwner][]string, func(), error) {
+	namespaceShardToExecutor, err := s.getNamespaceShardToExecutor(namespace)
+	if err != nil {
+		return nil, nil, fmt.Errorf("get namespace shard to executor: %w", err)
+	}
+
+	ch, unSub := namespaceShardToExecutor.Subscribe(ctx)
+	return ch, unSub, nil
+}
+
 func (s *ShardToExecutorCache) getNamespaceShardToExecutor(namespace string) (*namespaceShardToExecutor, error) {
 	s.RLock()
 	namespaceShardToExecutor, ok := s.namespaceToShards[namespace]
diff --git a/service/sharddistributor/store/store.go b/service/sharddistributor/store/store.go
index a9500408933..5d25fa8ecbe 100644
--- a/service/sharddistributor/store/store.go
+++ b/service/sharddistributor/store/store.go
@@ -63,6 +63,7 @@ type Store interface {
 	DeleteShardStats(ctx context.Context, namespace string, shardIDs []string, guard GuardFunc) error
 
 	GetShardOwner(ctx context.Context, namespace, shardID string) (*ShardOwner, error)
+	SubscribeToAssignmentChanges(ctx context.Context, namespace string) (<-chan map[*ShardOwner][]string, func(), error)
 	AssignShard(ctx context.Context, namespace, shardID, executorID string) error
 
 	GetHeartbeat(ctx context.Context, namespace string, executorID string) (*HeartbeatState, *AssignedState, error)
diff --git a/service/sharddistributor/store/store_mock.go b/service/sharddistributor/store/store_mock.go
index a246f8e3f85..685b4e01070 100644
--- a/service/sharddistributor/store/store_mock.go
+++ b/service/sharddistributor/store/store_mock.go
@@ -194,3 +194,19 @@ func (mr *MockStoreMockRecorder) Subscribe(ctx, namespace any) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscribe", reflect.TypeOf((*MockStore)(nil).Subscribe), ctx, namespace)
 }
+
+// SubscribeToAssignmentChanges mocks base method.
+func (m *MockStore) SubscribeToAssignmentChanges(ctx context.Context, namespace string) (<-chan map[*ShardOwner][]string, func(), error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "SubscribeToAssignmentChanges", ctx, namespace)
+	ret0, _ := ret[0].(<-chan map[*ShardOwner][]string)
+	ret1, _ := ret[1].(func())
+	ret2, _ := ret[2].(error)
+	return ret0, ret1, ret2
+}
+
+// SubscribeToAssignmentChanges indicates an expected call of SubscribeToAssignmentChanges.
+func (mr *MockStoreMockRecorder) SubscribeToAssignmentChanges(ctx, namespace any) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeToAssignmentChanges", reflect.TypeOf((*MockStore)(nil).SubscribeToAssignmentChanges), ctx, namespace)
+}
diff --git a/service/sharddistributor/store/wrappers/metered/store_generated.go b/service/sharddistributor/store/wrappers/metered/store_generated.go
index 91f1523ecfc..c47d478a71a 100644
--- a/service/sharddistributor/store/wrappers/metered/store_generated.go
+++ b/service/sharddistributor/store/wrappers/metered/store_generated.go
@@ -128,3 +128,13 @@ func (c *meteredStore) Subscribe(ctx context.Context, namespace string) (ch1 <-c
 	err = c.call(metrics.ShardDistributorStoreSubscribeScope, op, metrics.NamespaceTag(namespace))
 	return
 }
+
+func (c *meteredStore) SubscribeToAssignmentChanges(ctx context.Context, namespace string) (ch1 <-chan map[*store.ShardOwner][]string, f1 func(), err error) {
+	op := func() error {
+		ch1, f1, err = c.wrapped.SubscribeToAssignmentChanges(ctx, namespace)
+		return err
+	}
+
+	err = c.call(metrics.ShardDistributorStoreSubscribeToAssignmentChangesScope, op, metrics.NamespaceTag(namespace))
+	return
+}