diff --git a/service/matching/handler/engine.go b/service/matching/handler/engine.go index 7799fcfe44b..d888573ff11 100644 --- a/service/matching/handler/engine.go +++ b/service/matching/handler/engine.go @@ -97,6 +97,7 @@ type ( membershipResolver membership.Resolver isolationState isolationgroup.State timeSource clock.TimeSource + notificationVersion int64 } // HistoryInfo consists of two integer regarding the history size and history count @@ -162,6 +163,7 @@ func NewEngine( } func (e *matchingEngineImpl) Start() { + e.registerDomainFailoverCallback() } func (e *matchingEngineImpl) Stop() { @@ -170,6 +172,7 @@ func (e *matchingEngineImpl) Stop() { for _, l := range e.getTaskLists(math.MaxInt32) { l.Stop() } + e.unregisterDomainFailoverCallback() e.shutdownCompletion.Wait() } @@ -535,7 +538,7 @@ pollLoop: pollerCtx = tasklist.ContextWithIsolationGroup(pollerCtx, req.GetIsolationGroup()) tlMgr, err := e.getTaskListManager(taskListID, taskListKind) if err != nil { - return nil, fmt.Errorf("couldn't load tasklist namanger: %w", err) + return nil, fmt.Errorf("couldn't load tasklist manager: %w", err) } startT := time.Now() // Record the start time task, err := tlMgr.GetTask(pollerCtx, nil) @@ -724,7 +727,7 @@ pollLoop: taskListKind := request.TaskList.Kind tlMgr, err := e.getTaskListManager(taskListID, taskListKind) if err != nil { - return nil, fmt.Errorf("couldn't load tasklist namanger: %w", err) + return nil, fmt.Errorf("couldn't load tasklist manager: %w", err) } startT := time.Now() // Record the start time task, err := tlMgr.GetTask(pollerCtx, maxDispatch) @@ -1425,6 +1428,66 @@ func (e *matchingEngineImpl) isShuttingDown() bool { } } +func (e *matchingEngineImpl) domainChangeCallback(nextDomains []*cache.DomainCacheEntry) { + newNotificationVersion := e.notificationVersion + + for _, domain := range nextDomains { + if domain.GetNotificationVersion() > newNotificationVersion { + newNotificationVersion = domain.GetNotificationVersion() + } + + if !isDomainEligibleToDisconnectPollers(domain, e.notificationVersion) { + continue + } + + req := &types.GetTaskListsByDomainRequest{ + Domain: domain.GetInfo().Name, + } + + resp, err := e.GetTaskListsByDomain(nil, req) + if err != nil { + continue + } + + for taskListName := range resp.DecisionTaskListMap { + e.disconnectTaskListPollersAfterDomainFailover(taskListName, domain, persistence.TaskListTypeDecision) + } + + for taskListName := range resp.ActivityTaskListMap { + e.disconnectTaskListPollersAfterDomainFailover(taskListName, domain, persistence.TaskListTypeActivity) + } + } + e.notificationVersion = newNotificationVersion +} + +func (e *matchingEngineImpl) registerDomainFailoverCallback() { + e.domainCache.RegisterDomainChangeCallback( + service.Matching, + func(_ cache.DomainCache, _ cache.PrepareCallbackFn, _ cache.CallbackFn) {}, + func() {}, + e.domainChangeCallback) +} + +func (e *matchingEngineImpl) unregisterDomainFailoverCallback() { + e.domainCache.UnregisterDomainChangeCallback(service.Matching) +} + +func (e *matchingEngineImpl) disconnectTaskListPollersAfterDomainFailover(taskListName string, domain *cache.DomainCacheEntry, taskType int) { + taskList, err := tasklist.NewIdentifier(domain.GetInfo().ID, taskListName, taskType) + if err != nil { + return + } + tlMgr, err := e.getTaskListManager(taskList, types.TaskListKindNormal.Ptr()) + if err != nil { + e.logger.Error("Couldn't load tasklist manager", tag.Error(err)) + return + } + + if tlMgr.GetDomainActiveCluster() != "" && tlMgr.GetDomainActiveCluster() != domain.GetReplicationConfig().ActiveClusterName { + tlMgr.DisconnectBlockedPollers(&domain.GetReplicationConfig().ActiveClusterName) + } +} + func (m *lockableQueryTaskMap) put(key string, value chan *queryResult) { m.Lock() defer m.Unlock() @@ -1451,3 +1514,10 @@ func isMatchingRetryableError(err error) bool { } return true } + +func isDomainEligibleToDisconnectPollers(domain *cache.DomainCacheEntry, currentVersion int64) bool { + return domain.IsGlobalDomain() && + domain.GetReplicationConfig() != nil && + !domain.GetReplicationConfig().IsActiveActive() && + domain.GetNotificationVersion() > currentVersion +} diff --git a/service/matching/handler/engine_integration_test.go b/service/matching/handler/engine_integration_test.go index d450c003c0f..57aa9c95dae 100644 --- a/service/matching/handler/engine_integration_test.go +++ b/service/matching/handler/engine_integration_test.go @@ -131,6 +131,8 @@ func (s *matchingEngineSuite) SetupTest() { s.mockDomainCache.EXPECT().GetDomainByID(gomock.Any()).Return(cache.CreateDomainCacheEntry(matchingTestDomainName), nil).AnyTimes() s.mockDomainCache.EXPECT().GetDomain(gomock.Any()).Return(cache.CreateDomainCacheEntry(matchingTestDomainName), nil).AnyTimes() s.mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return(matchingTestDomainName, nil).AnyTimes() + s.mockDomainCache.EXPECT().RegisterDomainChangeCallback(service.Matching, gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + s.mockDomainCache.EXPECT().UnregisterDomainChangeCallback(service.Matching).AnyTimes() s.mockMembershipResolver = membership.NewMockResolver(s.controller) s.mockMembershipResolver.EXPECT().Lookup(gomock.Any(), gomock.Any()).Return(membership.HostInfo{}, nil).AnyTimes() s.mockMembershipResolver.EXPECT().WhoAmI().Return(membership.HostInfo{}, nil).AnyTimes() diff --git a/service/matching/handler/engine_test.go b/service/matching/handler/engine_test.go index 12f59910fce..9d140304db5 100644 --- a/service/matching/handler/engine_test.go +++ b/service/matching/handler/engine_test.go @@ -40,6 +40,7 @@ import ( "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/membership" "github.com/uber/cadence/common/metrics" + "github.com/uber/cadence/common/persistence" "github.com/uber/cadence/common/service" "github.com/uber/cadence/common/types" "github.com/uber/cadence/service/matching/config" @@ -645,7 +646,11 @@ func TestWaitForQueryResult(t *testing.T) { func TestIsShuttingDown(t *testing.T) { wg := sync.WaitGroup{} wg.Add(0) + mockDomainCache := cache.NewMockDomainCache(gomock.NewController(t)) + mockDomainCache.EXPECT().RegisterDomainChangeCallback(service.Matching, gomock.Any(), gomock.Any(), gomock.Any()).Times(1) + mockDomainCache.EXPECT().UnregisterDomainChangeCallback(service.Matching).Times(1) e := matchingEngineImpl{ + domainCache: mockDomainCache, shutdownCompletion: &wg, shutdown: make(chan struct{}), } @@ -1138,3 +1143,121 @@ func TestRefreshTaskListPartitionConfig(t *testing.T) { }) } } + +func Test_domainChangeCallback(t *testing.T) { + mockCtrl := gomock.NewController(t) + mockDomainCache := cache.NewMockDomainCache(mockCtrl) + + clusters := []string{"cluster0", "cluster1"} + + mockTaskListManagerGlobal1 := tasklist.NewMockManager(mockCtrl) + mockTaskListManagerGlobal2 := tasklist.NewMockManager(mockCtrl) + mockTaskListManagerGlobal3 := tasklist.NewMockManager(mockCtrl) + mockTaskListManagerLocal1 := tasklist.NewMockManager(mockCtrl) + mockTaskListManagerActiveActive1 := tasklist.NewMockManager(mockCtrl) + + tlIdentifierDecisionGlobal1, _ := tasklist.NewIdentifier("global-domain-1-id", "global-domain-1", persistence.TaskListTypeDecision) + tlIdentifierActivityGlobal1, _ := tasklist.NewIdentifier("global-domain-1-id", "global-domain-1", persistence.TaskListTypeActivity) + tlIdentifierDecisionGlobal2, _ := tasklist.NewIdentifier("global-domain-2-id", "global-domain-2", persistence.TaskListTypeDecision) + tlIdentifierActivityGlobal2, _ := tasklist.NewIdentifier("global-domain-2-id", "global-domain-2", persistence.TaskListTypeActivity) + tlIdentifierDecisionGlobal3, _ := tasklist.NewIdentifier("global-domain-3-id", "global-domain-3", persistence.TaskListTypeDecision) + tlIdentifierActivityGlobal3, _ := tasklist.NewIdentifier("global-domain-3-id", "global-domain-3", persistence.TaskListTypeActivity) + tlIdentifierDecisionLocal1, _ := tasklist.NewIdentifier("local-domain-1-id", "local-domain-1", persistence.TaskListTypeDecision) + tlIdentifierActivityLocal1, _ := tasklist.NewIdentifier("local-domain-1-id", "local-domain-1", persistence.TaskListTypeActivity) + tlIdentifierDecisionActiveActive1, _ := tasklist.NewIdentifier("active-active-domain-1-id", "active-active-domain-1", persistence.TaskListTypeDecision) + tlIdentifierActivityActiveActive1, _ := tasklist.NewIdentifier("active-active-domain-1-id", "active-active-domain-1", persistence.TaskListTypeActivity) + + engine := &matchingEngineImpl{ + domainCache: mockDomainCache, + notificationVersion: 0, + config: defaultTestConfig(), + taskLists: map[tasklist.Identifier]tasklist.Manager{ + *tlIdentifierDecisionGlobal1: mockTaskListManagerGlobal1, + *tlIdentifierActivityGlobal1: mockTaskListManagerGlobal1, + *tlIdentifierDecisionGlobal2: mockTaskListManagerGlobal2, + *tlIdentifierActivityGlobal2: mockTaskListManagerGlobal2, + *tlIdentifierDecisionGlobal3: mockTaskListManagerGlobal3, + *tlIdentifierActivityGlobal3: mockTaskListManagerGlobal3, + *tlIdentifierDecisionLocal1: mockTaskListManagerLocal1, + *tlIdentifierActivityLocal1: mockTaskListManagerLocal1, + *tlIdentifierDecisionActiveActive1: mockTaskListManagerActiveActive1, + *tlIdentifierActivityActiveActive1: mockTaskListManagerActiveActive1, + }, + } + + mockTaskListManagerGlobal1.EXPECT().DisconnectBlockedPollers(&clusters[0]).Times(0) + mockTaskListManagerGlobal2.EXPECT().GetDomainActiveCluster().Return(clusters[0]).Times(4) + mockTaskListManagerGlobal2.EXPECT().GetTaskListKind().Return(types.TaskListKindNormal).Times(2) + mockTaskListManagerGlobal2.EXPECT().DescribeTaskList(gomock.Any()).Return(&types.DescribeTaskListResponse{}).Times(2) + mockTaskListManagerGlobal2.EXPECT().DisconnectBlockedPollers(&clusters[1]).Times(2) + mockTaskListManagerGlobal3.EXPECT().GetDomainActiveCluster().Return(clusters[1]).Times(4) + mockTaskListManagerGlobal3.EXPECT().GetTaskListKind().Return(types.TaskListKindNormal).Times(2) + mockTaskListManagerGlobal3.EXPECT().DescribeTaskList(gomock.Any()).Return(&types.DescribeTaskListResponse{}).Times(2) + mockTaskListManagerGlobal3.EXPECT().DisconnectBlockedPollers(clusters[1]).Times(0) + mockTaskListManagerLocal1.EXPECT().DisconnectBlockedPollers(gomock.Any()).Times(0) + mockTaskListManagerActiveActive1.EXPECT().DisconnectBlockedPollers(gomock.Any()).Times(0) + mockDomainCache.EXPECT().GetDomainID("global-domain-2").Return("global-domain-2-id", nil).Times(1) + mockDomainCache.EXPECT().GetDomainID("global-domain-3").Return("global-domain-3-id", nil).Times(1) + + domains := []*cache.DomainCacheEntry{ + cache.NewDomainCacheEntryForTest( + &persistence.DomainInfo{Name: "global-domain-1", ID: "global-domain-1-id"}, + nil, + true, + &persistence.DomainReplicationConfig{ActiveClusterName: clusters[0], Clusters: []*persistence.ClusterReplicationConfig{{ClusterName: "cluster0"}, {ClusterName: "cluster1"}}}, + 0, + nil, + 0, + 0, + 0, + ), + cache.NewDomainCacheEntryForTest( + &persistence.DomainInfo{Name: "global-domain-2", ID: "global-domain-2-id"}, + nil, + true, + &persistence.DomainReplicationConfig{ActiveClusterName: clusters[1], Clusters: []*persistence.ClusterReplicationConfig{{ClusterName: "cluster0"}, {ClusterName: "cluster1"}}}, + 0, + nil, + 0, + 0, + 4, + ), + cache.NewDomainCacheEntryForTest( + &persistence.DomainInfo{Name: "global-domain-3", ID: "global-domain-3-id"}, + nil, + true, + &persistence.DomainReplicationConfig{ActiveClusterName: clusters[1], Clusters: []*persistence.ClusterReplicationConfig{{ClusterName: "cluster0"}, {ClusterName: "cluster1"}}}, + 0, + nil, + 0, + 0, + 5, + ), + cache.NewDomainCacheEntryForTest( + &persistence.DomainInfo{Name: "local-domain-1", ID: "local-domain-1-id"}, + nil, + false, + nil, + 0, + nil, + 0, + 0, + 3, + ), + cache.NewDomainCacheEntryForTest( + &persistence.DomainInfo{Name: "active-active-domain-1", ID: "active-active-domain-1-id"}, + nil, + true, + &persistence.DomainReplicationConfig{ActiveClusters: &persistence.ActiveClustersConfig{}}, + 0, + nil, + 0, + 0, + 3, + ), + } + + engine.domainChangeCallback(domains) + + assert.Equal(t, int64(5), engine.notificationVersion) +} diff --git a/service/matching/handler/handler.go b/service/matching/handler/handler.go index 082503c7763..9f6d76cb538 100644 --- a/service/matching/handler/handler.go +++ b/service/matching/handler/handler.go @@ -89,6 +89,7 @@ func NewHandler( // Start starts the handler func (h *handlerImpl) Start() { + h.engine.Start() h.startWG.Done() } diff --git a/service/matching/handler/handler_test.go b/service/matching/handler/handler_test.go index aaebd9b8eb1..f2f3112f2c4 100644 --- a/service/matching/handler/handler_test.go +++ b/service/matching/handler/handler_test.go @@ -123,6 +123,8 @@ func (s *handlerSuite) TestStart() { cfg := config.NewConfig(dynamicconfig.NewCollection(dynamicconfig.NewInMemoryClient(), s.mockResource.Logger), "matching-test", getIsolationGroupsHelper) handler := s.getHandler(cfg) + s.mockEngine.EXPECT().Start().Times(1) + handler.Start() } @@ -132,6 +134,7 @@ func (s *handlerSuite) TestStop() { cfg := config.NewConfig(dynamicconfig.NewCollection(dynamicconfig.NewInMemoryClient(), s.mockResource.Logger), "matching-test", getIsolationGroupsHelper) handler := s.getHandler(cfg) + s.mockEngine.EXPECT().Start().Times(1) s.mockEngine.EXPECT().Stop().Times(1) handler.Start() diff --git a/service/matching/handler/membership_test.go b/service/matching/handler/membership_test.go index 21720f8d8cc..84c926b8c28 100644 --- a/service/matching/handler/membership_test.go +++ b/service/matching/handler/membership_test.go @@ -207,14 +207,17 @@ func TestSubscriptionAndShutdown(t *testing.T) { shutdownWG := &sync.WaitGroup{} shutdownWG.Add(1) + mockDomainCache := cache.NewMockDomainCache(ctrl) + e := matchingEngineImpl{ shutdownCompletion: shutdownWG, membershipResolver: m, config: &config.Config{ EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }, }, - shutdown: make(chan struct{}), - logger: log.NewNoop(), + shutdown: make(chan struct{}), + logger: log.NewNoop(), + domainCache: mockDomainCache, } // anytimes here because this is quite a racy test and the actual assertions for the unsubscription logic will be separated out @@ -228,6 +231,7 @@ func TestSubscriptionAndShutdown(t *testing.T) { } inc <- &m }) + mockDomainCache.EXPECT().UnregisterDomainChangeCallback(service.Matching).Times(1) go func() { // then call stop so the test can finish @@ -242,6 +246,8 @@ func TestSubscriptionAndErrorReturned(t *testing.T) { ctrl := gomock.NewController(t) m := membership.NewMockResolver(ctrl) + mockDomainCache := cache.NewMockDomainCache(ctrl) + shutdownWG := sync.WaitGroup{} shutdownWG.Add(1) @@ -251,8 +257,9 @@ func TestSubscriptionAndErrorReturned(t *testing.T) { config: &config.Config{ EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }, }, - shutdown: make(chan struct{}), - logger: log.NewNoop(), + shutdown: make(chan struct{}), + logger: log.NewNoop(), + domainCache: mockDomainCache, } // this should trigger the error case on a membership event @@ -268,6 +275,8 @@ func TestSubscriptionAndErrorReturned(t *testing.T) { inc <- &m }) + mockDomainCache.EXPECT().UnregisterDomainChangeCallback(service.Matching).Times(1) + go func() { // then call stop so the test can finish time.Sleep(time.Second) @@ -281,6 +290,8 @@ func TestSubscribeToMembershipChangesQuitsIfSubscribeFails(t *testing.T) { ctrl := gomock.NewController(t) m := membership.NewMockResolver(ctrl) + mockDomainCache := cache.NewMockDomainCache(ctrl) + logger, logs := testlogger.NewObserved(t) shutdownWG := sync.WaitGroup{} @@ -292,8 +303,9 @@ func TestSubscribeToMembershipChangesQuitsIfSubscribeFails(t *testing.T) { config: &config.Config{ EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }, }, - shutdown: make(chan struct{}), - logger: logger, + shutdown: make(chan struct{}), + logger: logger, + domainCache: mockDomainCache, } // this should trigger the error case on a membership event @@ -302,6 +314,8 @@ func TestSubscribeToMembershipChangesQuitsIfSubscribeFails(t *testing.T) { m.EXPECT().Subscribe(service.Matching, "matching-engine", gomock.Any()). Return(errors.New("matching-engine is already subscribed to updates")) + mockDomainCache.EXPECT().UnregisterDomainChangeCallback(service.Matching).AnyTimes() + go func() { // then call stop so the test can finish time.Sleep(time.Second) @@ -324,9 +338,12 @@ func TestGetTasklistManagerShutdownScenario(t *testing.T) { ctrl := gomock.NewController(t) m := membership.NewMockResolver(ctrl) + mockDomainCache := cache.NewMockDomainCache(ctrl) + self := membership.NewDetailedHostInfo("self", "self", nil) m.EXPECT().WhoAmI().Return(self, nil).AnyTimes() + mockDomainCache.EXPECT().UnregisterDomainChangeCallback(service.Matching).Times(1) shutdownWG := sync.WaitGroup{} shutdownWG.Add(0) @@ -337,8 +354,9 @@ func TestGetTasklistManagerShutdownScenario(t *testing.T) { config: &config.Config{ EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }, }, - shutdown: make(chan struct{}), - logger: log.NewNoop(), + shutdown: make(chan struct{}), + logger: log.NewNoop(), + domainCache: mockDomainCache, } // set this engine to be shutting down so as to trigger the tasklistGetTasklistByID guard diff --git a/service/matching/tasklist/interfaces.go b/service/matching/tasklist/interfaces.go index f7dec24a953..020f4beb31d 100644 --- a/service/matching/tasklist/interfaces.go +++ b/service/matching/tasklist/interfaces.go @@ -64,6 +64,8 @@ type ( UpdateTaskListPartitionConfig(context.Context, *types.TaskListPartitionConfig) error RefreshTaskListPartitionConfig(context.Context, *types.TaskListPartitionConfig) error LoadBalancerHints() *types.LoadBalancerHints + GetDomainActiveCluster() string + DisconnectBlockedPollers(domainActiveCluster *string) } TaskMatcher interface { @@ -76,6 +78,7 @@ type ( PollForQuery(ctx context.Context) (*InternalTask, error) UpdateRatelimit(rps *float64) Rate() float64 + RefreshCancelContext() } Forwarder interface { diff --git a/service/matching/tasklist/interfaces_mock.go b/service/matching/tasklist/interfaces_mock.go index 67c14650bb3..5fcbc5ad981 100644 --- a/service/matching/tasklist/interfaces_mock.go +++ b/service/matching/tasklist/interfaces_mock.go @@ -84,6 +84,18 @@ func (mr *MockManagerMockRecorder) DescribeTaskList(includeTaskListStatus any) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeTaskList", reflect.TypeOf((*MockManager)(nil).DescribeTaskList), includeTaskListStatus) } +// DisconnectBlockedPollers mocks base method. +func (m *MockManager) DisconnectBlockedPollers(domainActiveCluster *string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DisconnectBlockedPollers", domainActiveCluster) +} + +// DisconnectBlockedPollers indicates an expected call of DisconnectBlockedPollers. +func (mr *MockManagerMockRecorder) DisconnectBlockedPollers(domainActiveCluster any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectBlockedPollers", reflect.TypeOf((*MockManager)(nil).DisconnectBlockedPollers), domainActiveCluster) +} + // DispatchQueryTask mocks base method. func (m *MockManager) DispatchQueryTask(ctx context.Context, taskID string, request *types.MatchingQueryWorkflowRequest) (*types.MatchingQueryWorkflowResponse, error) { m.ctrl.T.Helper() @@ -127,6 +139,20 @@ func (mr *MockManagerMockRecorder) GetAllPollerInfo() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllPollerInfo", reflect.TypeOf((*MockManager)(nil).GetAllPollerInfo)) } +// GetDomainActiveCluster mocks base method. +func (m *MockManager) GetDomainActiveCluster() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDomainActiveCluster") + ret0, _ := ret[0].(string) + return ret0 +} + +// GetDomainActiveCluster indicates an expected call of GetDomainActiveCluster. +func (mr *MockManagerMockRecorder) GetDomainActiveCluster() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDomainActiveCluster", reflect.TypeOf((*MockManager)(nil).GetDomainActiveCluster)) +} + // GetTask mocks base method. func (m *MockManager) GetTask(ctx context.Context, maxDispatchPerSecond *float64) (*InternalTask, error) { m.ctrl.T.Helper() @@ -419,6 +445,18 @@ func (mr *MockTaskMatcherMockRecorder) Rate() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rate", reflect.TypeOf((*MockTaskMatcher)(nil).Rate)) } +// RefreshCancelContext mocks base method. +func (m *MockTaskMatcher) RefreshCancelContext() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "RefreshCancelContext") +} + +// RefreshCancelContext indicates an expected call of RefreshCancelContext. +func (mr *MockTaskMatcherMockRecorder) RefreshCancelContext() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshCancelContext", reflect.TypeOf((*MockTaskMatcher)(nil).RefreshCancelContext)) +} + // UpdateRatelimit mocks base method. func (m *MockTaskMatcher) UpdateRatelimit(rps *float64) { m.ctrl.T.Helper() diff --git a/service/matching/tasklist/matcher.go b/service/matching/tasklist/matcher.go index ad50c1825f7..5f7f6e08dd6 100644 --- a/service/matching/tasklist/matcher.go +++ b/service/matching/tasklist/matcher.go @@ -464,10 +464,14 @@ func (tm *taskMatcherImpl) PollForQuery(ctx context.Context) (*InternalTask, err tm.scope.RecordTimer(metrics.PollLocalMatchLatencyPerTaskList, time.Since(startT)) return task, nil } + + ctxWithCancelPropagation, stopFn := ctxutils.WithPropagatedContextCancel(ctx, tm.cancelCtx) + defer stopFn() + // there is no local poller available to pickup this task. Now block waiting // either for a local poller or a forwarding token to be available. When a // forwarding token becomes available, send this poll to a parent partition - return tm.pollOrForward(ctx, startT, "", nil, nil, tm.queryTaskC) + return tm.pollOrForward(ctxWithCancelPropagation, startT, "", nil, nil, tm.queryTaskC) } // UpdateRatelimit updates the task dispatch rate @@ -488,6 +492,10 @@ func (tm *taskMatcherImpl) Rate() float64 { return rate } +func (tm *taskMatcherImpl) RefreshCancelContext() { + tm.cancelCtx, tm.cancelFunc = context.WithCancel(context.Background()) +} + func (tm *taskMatcherImpl) pollOrForward( ctx context.Context, startT time.Time, diff --git a/service/matching/tasklist/matcher_test.go b/service/matching/tasklist/matcher_test.go index 9f0ab19d696..2d2d9e91a01 100644 --- a/service/matching/tasklist/matcher_test.go +++ b/service/matching/tasklist/matcher_test.go @@ -568,12 +568,28 @@ func (t *MatcherTestSuite) TestPollersDisconnectedAfterDisconnectBlockedPollers( t.disableRemoteForwarding() t.matcher.DisconnectBlockedPollers() - longPollingCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + longPollingCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() task, err := t.matcher.Poll(longPollingCtx, "") t.ErrorIs(err, ErrNoTasks, "closed matcher should result in no tasks") t.Nil(task) + t.NoError(longPollingCtx.Err(), "the parent context was not cancelled, the child context was cancelled") +} + +func (t *MatcherTestSuite) TestPollersConnectedAfterDisconnectBlockedPollersSetCancelContext() { + defer goleak.VerifyNone(t.T()) + t.disableRemoteForwarding() + t.matcher.DisconnectBlockedPollers() + t.matcher.RefreshCancelContext() + + longPollingCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + task, err := t.matcher.Poll(longPollingCtx, "") + t.ErrorIs(err, ErrNoTasks, "no tasks to be matched to poller") + t.Nil(task) + t.ErrorIs(longPollingCtx.Err(), context.DeadlineExceeded, "the child context wasn't cancelled, the parent context timed out") } func (t *MatcherTestSuite) TestRemotePoll() { diff --git a/service/matching/tasklist/task_list_manager.go b/service/matching/tasklist/task_list_manager.go index 2bea1116278..0beb958186a 100644 --- a/service/matching/tasklist/task_list_manager.go +++ b/service/matching/tasklist/task_list_manager.go @@ -94,26 +94,27 @@ type ( // Single task list in memory state taskListManagerImpl struct { - createTime time.Time - enableIsolation bool - taskListID *Identifier - taskListKind types.TaskListKind // sticky taskList has different process in persistence - config *config.TaskListConfig - db *taskListDB - taskWriter *taskWriter - taskReader *taskReader // reads tasks from db and async matches it with poller - liveness *liveness.Liveness - taskGC *taskGC - taskAckManager messaging.AckManager // tracks ackLevel for delivered messages - matcher TaskMatcher // for matching a task producer with a poller - clusterMetadata cluster.Metadata - domainCache cache.DomainCache - isolationState isolationgroup.State - logger log.Logger - scope metrics.Scope - timeSource clock.TimeSource - matchingClient matching.Client - domainName string + createTime time.Time + enableIsolation bool + taskListID *Identifier + taskListKind types.TaskListKind // sticky taskList has different process in persistence + config *config.TaskListConfig + db *taskListDB + taskWriter *taskWriter + taskReader *taskReader // reads tasks from db and async matches it with poller + liveness *liveness.Liveness + taskGC *taskGC + taskAckManager messaging.AckManager // tracks ackLevel for delivered messages + matcher TaskMatcher // for matching a task producer with a poller + clusterMetadata cluster.Metadata + domainCache cache.DomainCache + isolationState isolationgroup.State + logger log.Logger + scope metrics.Scope + timeSource clock.TimeSource + matchingClient matching.Client + domainName string + domainActiveCluster string // pollers stores poller which poll from this tasklist in last few minutes pollers poller.Manager startWG sync.WaitGroup // ensures that background processes do not start until setup is ready @@ -155,11 +156,21 @@ func NewManager( createTime time.Time, historyService history.Client, ) (Manager, error) { - domainName, err := domainCache.GetDomainName(taskList.GetDomainID()) + domain, err := domainCache.GetDomainByID(taskList.GetDomainID()) if err != nil { return nil, err } + domainName := domain.GetInfo().Name + + // set domainActiveCluster for global active-passive domains and local domains + var domainActiveCluster string + if domain.IsGlobalDomain() && domain.GetReplicationConfig() != nil && !domain.GetReplicationConfig().IsActiveActive() { + domainActiveCluster = domain.GetReplicationConfig().ActiveClusterName + } else if !domain.IsGlobalDomain() { + domainActiveCluster = clusterMetadata.GetCurrentClusterName() + } + taskListConfig := newTaskListConfig(taskList, cfg, domainName) if taskListKind == nil { @@ -172,23 +183,24 @@ func NewManager( db := newTaskListDB(taskManager, taskList.GetDomainID(), domainName, taskList.GetName(), taskList.GetType(), int(*taskListKind), logger) tlMgr := &taskListManagerImpl{ - createTime: createTime, - enableIsolation: taskListConfig.EnableTasklistIsolation(), - domainCache: domainCache, - clusterMetadata: clusterMetadata, - isolationState: isolationState, - taskListID: taskList, - taskListKind: *taskListKind, - logger: logger.WithTags(tag.WorkflowDomainName(domainName), tag.WorkflowTaskListName(taskList.GetName()), tag.WorkflowTaskListType(taskList.GetType())), - db: db, - taskAckManager: messaging.NewAckManager(logger), - taskGC: newTaskGC(db, taskListConfig), - config: taskListConfig, - matchingClient: matchingClient, - domainName: domainName, - scope: scope, - timeSource: timeSource, - closeCallback: closeCallback, + createTime: createTime, + enableIsolation: taskListConfig.EnableTasklistIsolation(), + domainCache: domainCache, + clusterMetadata: clusterMetadata, + isolationState: isolationState, + taskListID: taskList, + taskListKind: *taskListKind, + logger: logger.WithTags(tag.WorkflowDomainName(domainName), tag.WorkflowTaskListName(taskList.GetName()), tag.WorkflowTaskListType(taskList.GetType())), + db: db, + taskAckManager: messaging.NewAckManager(logger), + taskGC: newTaskGC(db, taskListConfig), + config: taskListConfig, + matchingClient: matchingClient, + domainName: domainName, + domainActiveCluster: domainActiveCluster, + scope: scope, + timeSource: timeSource, + closeCallback: closeCallback, throttleRetry: backoff.NewThrottleRetry( backoff.WithRetryPolicy(persistenceOperationRetryPolicy), backoff.WithRetryableError(persistence.IsTransientError), @@ -757,6 +769,18 @@ func (c *taskListManagerImpl) DescribeTaskList(includeTaskListStatus bool) *type return response } +func (c *taskListManagerImpl) GetDomainActiveCluster() string { + return c.domainActiveCluster +} + +func (c *taskListManagerImpl) DisconnectBlockedPollers(domainActiveCluster *string) { + if domainActiveCluster != nil { + c.domainActiveCluster = *domainActiveCluster + } + c.matcher.DisconnectBlockedPollers() + c.matcher.RefreshCancelContext() +} + func (c *taskListManagerImpl) String() string { buf := new(bytes.Buffer) if c.taskListID.GetType() == persistence.TaskListTypeActivity { diff --git a/service/matching/tasklist/task_list_manager_test.go b/service/matching/tasklist/task_list_manager_test.go index ca5549420fa..080fa380c37 100644 --- a/service/matching/tasklist/task_list_manager_test.go +++ b/service/matching/tasklist/task_list_manager_test.go @@ -85,7 +85,8 @@ func setupMocksForTaskListManager(t *testing.T, taskListID *Identifier, taskList mockTimeSource: clock.NewMockedTimeSource(), dynamicClient: dynamicClient, } - deps.mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("domainName", nil).Times(1) + domain := cache.CreateDomainCacheEntry("domainName") + deps.mockDomainCache.EXPECT().GetDomainByID(gomock.Any()).Return(domain, nil).Times(1) config := config.NewConfig(dynamicconfig.NewCollection(dynamicClient, logger), "hostname", getIsolationgroupsHelper) mockHistoryService := history.NewMockClient(ctrl) @@ -1822,6 +1823,130 @@ func TestGetNumPartitions(t *testing.T) { assert.NotPanics(t, func() { tlm.matcher.UpdateRatelimit(common.Ptr(float64(100))) }) } +func TestGetDomainActiveCluster(t *testing.T) { + tests := []struct { + name string + domainCacheEntry *cache.DomainCacheEntry + expectedCluster string + }{ + { + name: "global active-passive domain", + domainCacheEntry: cache.NewGlobalDomainCacheEntryForTest( + &persistence.DomainInfo{ID: constants.TestDomainID, Name: constants.TestDomainName}, + &persistence.DomainConfig{Retention: 1}, + &persistence.DomainReplicationConfig{ + ActiveClusterName: cluster.TestCurrentClusterName, + Clusters: []*persistence.ClusterReplicationConfig{{ClusterName: cluster.TestCurrentClusterName}, {ClusterName: cluster.TestAlternativeClusterName}}, + }, + 1, + ), + expectedCluster: cluster.TestCurrentClusterName, + }, + { + name: "local domain", + domainCacheEntry: cache.NewLocalDomainCacheEntryForTest( + &persistence.DomainInfo{ID: constants.TestDomainID, Name: constants.TestDomainName}, + &persistence.DomainConfig{Retention: 1}, + cluster.TestCurrentClusterName, + ), + expectedCluster: cluster.TestCurrentClusterName, + }, + { + name: "global active-active domain", + domainCacheEntry: cache.NewGlobalDomainCacheEntryForTest( + &persistence.DomainInfo{ID: constants.TestDomainID, Name: constants.TestDomainName}, + &persistence.DomainConfig{Retention: 1}, + &persistence.DomainReplicationConfig{ + ActiveClusters: &persistence.ActiveClustersConfig{RegionToClusterMap: map[string]persistence.ActiveClusterConfig{}}, + Clusters: []*persistence.ClusterReplicationConfig{{ClusterName: cluster.TestCurrentClusterName}, {ClusterName: cluster.TestAlternativeClusterName}}, + }, + 1, + ), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + controller := gomock.NewController(t) + logger := testlogger.New(t) + timeSource := clock.NewMockedTimeSource() + + tm := NewTestTaskManager(t, logger, timeSource) + mockIsolationState := isolationgroup.NewMockState(controller) + mockIsolationState.EXPECT().IsDrained(gomock.Any(), "domainName", gomock.Any()).Return(false, nil).AnyTimes() + mockDomainCache := cache.NewMockDomainCache(controller) + mockDomainCache.EXPECT().GetDomainByID(gomock.Any()).Return(tc.domainCacheEntry, nil).AnyTimes() + mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return("domainName", nil).AnyTimes() + mockHistoryService := history.NewMockClient(controller) + tl := "tl" + dID := "domain" + tlID, err := NewIdentifier(dID, tl, persistence.TaskListTypeActivity) + if err != nil { + panic(err) + } + tlKind := types.TaskListKindNormal + tlMgr, err := NewManager( + mockDomainCache, + logger, + metrics.NewClient(tally.NoopScope, metrics.Matching), + tm, + cluster.GetTestClusterMetadata(true), + mockIsolationState, + nil, + func(Manager) {}, + tlID, + &tlKind, + config.NewConfig(dynamicconfig.NewCollection(dynamicconfig.NewInMemoryClient(), logger), "hostname", getIsolationgroupsHelper), + timeSource, + timeSource.Now(), + mockHistoryService, + ) + + activeClusterName := tlMgr.GetDomainActiveCluster() + assert.Equal(t, tc.expectedCluster, activeClusterName) + }) + } +} + +func TestDisconnectBlockedPollers(t *testing.T) { + tests := []struct { + name string + newActiveClusterName *string + }{ + { + name: "test-disconnect-blocked-pollers and update active cluster", + newActiveClusterName: common.StringPtr("new-active-cluster"), + }, + { + name: "test-disconnect-blocked-pollers and not update active cluster", + newActiveClusterName: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tlID, err := NewIdentifier("domain-id", "tl", persistence.TaskListTypeDecision) + require.NoError(t, err) + + tlm, _ := setupMocksForTaskListManager(t, tlID, types.TaskListKindNormal) + + mockMatcher := NewMockTaskMatcher(gomock.NewController(t)) + tlm.matcher = mockMatcher + mockMatcher.EXPECT().DisconnectBlockedPollers().Times(1) + mockMatcher.EXPECT().RefreshCancelContext().Times(1) + currentActiveClusterName := tlm.GetDomainActiveCluster() + + tlm.DisconnectBlockedPollers(tc.newActiveClusterName) + + if tc.newActiveClusterName == nil { + assert.Equal(t, currentActiveClusterName, tlm.GetDomainActiveCluster()) + } else { + assert.Equal(t, *tc.newActiveClusterName, tlm.GetDomainActiveCluster()) + } + }) + } +} + func partitions(num int) map[int]*types.TaskListPartition { result := make(map[int]*types.TaskListPartition, num) for i := 0; i < num; i++ { diff --git a/simulation/replication/testdata/replication_simulation_default.yaml b/simulation/replication/testdata/replication_simulation_default.yaml index 658b1140d41..59c2284bcab 100644 --- a/simulation/replication/testdata/replication_simulation_default.yaml +++ b/simulation/replication/testdata/replication_simulation_default.yaml @@ -29,7 +29,7 @@ operations: # failoverTimeoutSec: 5 # unset means force failover. setting it means graceful failover request - op: validate - at: 120s # todo: this should work at 40s mark + at: 40s workflowID: wf1 cluster: cluster1 want: