diff --git a/service/matching/handler/engine.go b/service/matching/handler/engine.go
index 7799fcfe44b..2f7d89d8863 100644
--- a/service/matching/handler/engine.go
+++ b/service/matching/handler/engine.go
@@ -79,24 +79,25 @@ type (
 	}
 
 	matchingEngineImpl struct {
-		shutdownCompletion   *sync.WaitGroup
-		shutdown             chan struct{}
-		taskManager          persistence.TaskManager
-		clusterMetadata      cluster.Metadata
-		historyService       history.Client
-		matchingClient       matching.Client
-		tokenSerializer      common.TaskTokenSerializer
-		logger               log.Logger
-		metricsClient        metrics.Client
-		taskListsLock        sync.RWMutex                             // locks mutation of taskLists
-		taskLists            map[tasklist.Identifier]tasklist.Manager // Convert to LRU cache
-		config               *config.Config
-		lockableQueryTaskMap lockableQueryTaskMap
-		domainCache          cache.DomainCache
-		versionChecker       client.VersionChecker
-		membershipResolver   membership.Resolver
-		isolationState       isolationgroup.State
-		timeSource           clock.TimeSource
+		shutdownCompletion          *sync.WaitGroup
+		shutdown                    chan struct{}
+		taskManager                 persistence.TaskManager
+		clusterMetadata             cluster.Metadata
+		historyService              history.Client
+		matchingClient              matching.Client
+		tokenSerializer             common.TaskTokenSerializer
+		logger                      log.Logger
+		metricsClient               metrics.Client
+		taskListsLock               sync.RWMutex                             // locks mutation of taskLists
+		taskLists                   map[tasklist.Identifier]tasklist.Manager // Convert to LRU cache
+		config                      *config.Config
+		lockableQueryTaskMap        lockableQueryTaskMap
+		domainCache                 cache.DomainCache
+		versionChecker              client.VersionChecker
+		membershipResolver          membership.Resolver
+		isolationState              isolationgroup.State
+		timeSource                  clock.TimeSource
+		failoverNotificationVersion int64
 	}
 
 	// HistoryInfo consists of two integer regarding the history size and history count
@@ -162,6 +163,7 @@ func NewEngine(
 }
 
 func (e *matchingEngineImpl) Start() {
+	e.registerDomainFailoverCallback()
 }
 
 func (e *matchingEngineImpl) Stop() {
@@ -170,6 +172,7 @@ func (e *matchingEngineImpl) Stop() {
 	for _, l := range e.getTaskLists(math.MaxInt32) {
 		l.Stop()
 	}
+	e.unregisterDomainFailoverCallback()
 	e.shutdownCompletion.Wait()
 }
 
@@ -535,7 +538,7 @@ pollLoop:
 		pollerCtx = tasklist.ContextWithIsolationGroup(pollerCtx, req.GetIsolationGroup())
 		tlMgr, err := e.getTaskListManager(taskListID, taskListKind)
 		if err != nil {
-			return nil, fmt.Errorf("couldn't load tasklist namanger: %w", err)
+			return nil, fmt.Errorf("couldn't load tasklist manager: %w", err)
 		}
 		startT := time.Now() // Record the start time
 		task, err := tlMgr.GetTask(pollerCtx, nil)
@@ -724,7 +727,7 @@ pollLoop:
 		taskListKind := request.TaskList.Kind
 		tlMgr, err := e.getTaskListManager(taskListID, taskListKind)
 		if err != nil {
-			return nil, fmt.Errorf("couldn't load tasklist namanger: %w", err)
+			return nil, fmt.Errorf("couldn't load tasklist manager: %w", err)
 		}
 		startT := time.Now() // Record the start time
 		task, err := tlMgr.GetTask(pollerCtx, maxDispatch)
@@ -1425,6 +1428,82 @@ func (e *matchingEngineImpl) isShuttingDown() bool {
 	}
 }
 
+func (e *matchingEngineImpl) domainChangeCallback(nextDomains []*cache.DomainCacheEntry) {
+	newFailoverNotificationVersion := e.failoverNotificationVersion
+
+	for _, domain := range nextDomains {
+		if domain.GetFailoverNotificationVersion() > newFailoverNotificationVersion {
+			newFailoverNotificationVersion = domain.GetFailoverNotificationVersion()
+		}
+
+		if !isDomainEligibleToDisconnectPollers(domain, e.failoverNotificationVersion) {
+			continue
+		}
+
+		req := &types.GetTaskListsByDomainRequest{
+			Domain: domain.GetInfo().Name,
+		}
+
+		resp, err := e.GetTaskListsByDomain(nil, req)
+		if err != nil {
+			continue
+		}
+
+		for taskListName := range resp.DecisionTaskListMap {
+			e.disconnectTaskListPollersAfterDomainFailover(taskListName, domain, persistence.TaskListTypeDecision)
+		}
+
+		for taskListName := range resp.ActivityTaskListMap {
+			e.disconnectTaskListPollersAfterDomainFailover(taskListName, domain, persistence.TaskListTypeActivity)
+		}
+	}
+	e.failoverNotificationVersion = newFailoverNotificationVersion
+}
+
+func (e *matchingEngineImpl) registerDomainFailoverCallback() {
+	catchUpFn := func(domainCache cache.DomainCache, _ cache.PrepareCallbackFn, _ cache.CallbackFn) {
+		for _, domain := range domainCache.GetAllDomain() {
+			if domain.GetFailoverNotificationVersion() > e.failoverNotificationVersion {
+				e.failoverNotificationVersion = domain.GetFailoverNotificationVersion()
+			}
+		}
+	}
+
+	e.domainCache.RegisterDomainChangeCallback(
+		service.Matching,
+		catchUpFn,
+		func() {},
+		e.domainChangeCallback)
+}
+
+func (e *matchingEngineImpl) unregisterDomainFailoverCallback() {
+	e.domainCache.UnregisterDomainChangeCallback(service.Matching)
+}
+
+func (e *matchingEngineImpl) disconnectTaskListPollersAfterDomainFailover(taskListName string, domain *cache.DomainCacheEntry, taskType int) {
+	taskList, err := tasklist.NewIdentifier(domain.GetInfo().ID, taskListName, taskType)
+	if err != nil {
+		return
+	}
+	tlMgr, err := e.getTaskListManager(taskList, types.TaskListKindNormal.Ptr())
+	if err != nil {
+		e.logger.Error("Couldn't load tasklist manager", tag.Error(err))
+		return
+	}
+
+	err = tlMgr.ReleaseBlockedPollers()
+	if err != nil {
+		e.logger.Error("Couldn't disconnect tasklist pollers after domain failover",
+			tag.Error(err),
+			tag.WorkflowDomainID(domain.GetInfo().ID),
+			tag.WorkflowDomainName(domain.GetInfo().Name),
+			tag.WorkflowTaskListName(taskListName),
+			tag.WorkflowTaskListType(taskType),
+		)
+		return
+	}
+}
+
 func (m *lockableQueryTaskMap) put(key string, value chan *queryResult) {
 	m.Lock()
 	defer m.Unlock()
@@ -1451,3 +1530,10 @@ func isMatchingRetryableError(err error) bool {
 	}
 	return true
 }
+
+func isDomainEligibleToDisconnectPollers(domain *cache.DomainCacheEntry, currentVersion int64) bool {
+	return domain.IsGlobalDomain() &&
+		domain.GetReplicationConfig() != nil &&
+		!domain.GetReplicationConfig().IsActiveActive() &&
+		domain.GetFailoverNotificationVersion() > currentVersion
+}
diff --git a/service/matching/handler/engine_integration_test.go b/service/matching/handler/engine_integration_test.go
index d450c003c0f..57aa9c95dae 100644
--- a/service/matching/handler/engine_integration_test.go
+++ b/service/matching/handler/engine_integration_test.go
@@ -131,6 +131,8 @@ func (s *matchingEngineSuite) SetupTest() {
 	s.mockDomainCache.EXPECT().GetDomainByID(gomock.Any()).Return(cache.CreateDomainCacheEntry(matchingTestDomainName), nil).AnyTimes()
 	s.mockDomainCache.EXPECT().GetDomain(gomock.Any()).Return(cache.CreateDomainCacheEntry(matchingTestDomainName), nil).AnyTimes()
 	s.mockDomainCache.EXPECT().GetDomainName(gomock.Any()).Return(matchingTestDomainName, nil).AnyTimes()
+	s.mockDomainCache.EXPECT().RegisterDomainChangeCallback(service.Matching, gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
+	s.mockDomainCache.EXPECT().UnregisterDomainChangeCallback(service.Matching).AnyTimes()
 	s.mockMembershipResolver = membership.NewMockResolver(s.controller)
 	s.mockMembershipResolver.EXPECT().Lookup(gomock.Any(), gomock.Any()).Return(membership.HostInfo{}, nil).AnyTimes()
 	s.mockMembershipResolver.EXPECT().WhoAmI().Return(membership.HostInfo{}, nil).AnyTimes()
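Note (illustration only, not part of the patch): the callback above gates disconnects on a per-domain failover notification version kept as a high-water mark on the engine. Domains whose version has not moved past the previously processed mark are skipped, and the mark is advanced only after the whole batch is handled. A minimal, self-contained Go sketch of that gating — the type and names here are hypothetical stand-ins, not Cadence APIs:

package main

import "fmt"

// domainUpdate is a hypothetical stand-in for the fields the real callback
// reads from a cache.DomainCacheEntry.
type domainUpdate struct {
	name                        string
	failoverNotificationVersion int64
}

// processUpdates mirrors the gating in domainChangeCallback: advance a
// high-water mark over the batch, but act only on domains whose version
// moved strictly past the previously processed mark.
func processUpdates(lastSeen int64, updates []domainUpdate) int64 {
	newLastSeen := lastSeen
	for _, d := range updates {
		if d.failoverNotificationVersion > newLastSeen {
			newLastSeen = d.failoverNotificationVersion
		}
		if d.failoverNotificationVersion <= lastSeen {
			continue // stale or already-processed update: no disconnect
		}
		fmt.Printf("would disconnect pollers for %q\n", d.name)
	}
	return newLastSeen // stored back once the whole batch is handled
}

func main() {
	v := processUpdates(1, []domainUpdate{{"d1", 0}, {"d2", 4}, {"d3", 5}})
	fmt.Println("new high-water mark:", v) // prints 5; d1 never disconnects
}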
diff --git a/service/matching/handler/engine_test.go b/service/matching/handler/engine_test.go
index 12f59910fce..27e112c021f 100644
--- a/service/matching/handler/engine_test.go
+++ b/service/matching/handler/engine_test.go
@@ -40,6 +40,7 @@ import (
 	"github.com/uber/cadence/common/log"
 	"github.com/uber/cadence/common/membership"
 	"github.com/uber/cadence/common/metrics"
+	"github.com/uber/cadence/common/persistence"
 	"github.com/uber/cadence/common/service"
 	"github.com/uber/cadence/common/types"
 	"github.com/uber/cadence/service/matching/config"
@@ -645,7 +646,11 @@ func TestWaitForQueryResult(t *testing.T) {
 func TestIsShuttingDown(t *testing.T) {
 	wg := sync.WaitGroup{}
 	wg.Add(0)
+	mockDomainCache := cache.NewMockDomainCache(gomock.NewController(t))
+	mockDomainCache.EXPECT().RegisterDomainChangeCallback(service.Matching, gomock.Any(), gomock.Any(), gomock.Any()).Times(1)
+	mockDomainCache.EXPECT().UnregisterDomainChangeCallback(service.Matching).Times(1)
 	e := matchingEngineImpl{
+		domainCache:        mockDomainCache,
 		shutdownCompletion: &wg,
 		shutdown:           make(chan struct{}),
 	}
@@ -1138,3 +1143,236 @@ func TestRefreshTaskListPartitionConfig(t *testing.T) {
 		})
 	}
 }
+
+func Test_domainChangeCallback(t *testing.T) {
+	mockCtrl := gomock.NewController(t)
+	mockDomainCache := cache.NewMockDomainCache(mockCtrl)
+
+	clusters := []string{"cluster0", "cluster1"}
+
+	mockTaskListManagerGlobal1 := tasklist.NewMockManager(mockCtrl)
+	mockTaskListManagerGlobal2 := tasklist.NewMockManager(mockCtrl)
+	mockTaskListManagerGlobal3 := tasklist.NewMockManager(mockCtrl)
+	mockTaskListManagerLocal1 := tasklist.NewMockManager(mockCtrl)
+	mockTaskListManagerActiveActive1 := tasklist.NewMockManager(mockCtrl)
+
+	tlIdentifierDecisionGlobal1, _ := tasklist.NewIdentifier("global-domain-1-id", "global-domain-1", persistence.TaskListTypeDecision)
+	tlIdentifierActivityGlobal1, _ := tasklist.NewIdentifier("global-domain-1-id", "global-domain-1", persistence.TaskListTypeActivity)
+	tlIdentifierDecisionGlobal2, _ := tasklist.NewIdentifier("global-domain-2-id", "global-domain-2", persistence.TaskListTypeDecision)
+	tlIdentifierActivityGlobal2, _ := tasklist.NewIdentifier("global-domain-2-id", "global-domain-2", persistence.TaskListTypeActivity)
+	tlIdentifierActivityGlobal3, _ := tasklist.NewIdentifier("global-domain-3-id", "global-domain-3", persistence.TaskListTypeActivity)
+	tlIdentifierDecisionGlobal3, _ := tasklist.NewIdentifier("global-domain-3-id", "global-domain-3", persistence.TaskListTypeDecision)
+	tlIdentifierDecisionLocal1, _ := tasklist.NewIdentifier("local-domain-1-id", "local-domain-1", persistence.TaskListTypeDecision)
+	tlIdentifierActivityLocal1, _ := tasklist.NewIdentifier("local-domain-1-id", "local-domain-1", persistence.TaskListTypeActivity)
+	tlIdentifierDecisionActiveActive1, _ := tasklist.NewIdentifier("active-active-domain-1-id", "active-active-domain-1", persistence.TaskListTypeDecision)
+	tlIdentifierActivityActiveActive1, _ := tasklist.NewIdentifier("active-active-domain-1-id", "active-active-domain-1", persistence.TaskListTypeActivity)
+
+	engine := &matchingEngineImpl{
+		domainCache:                 mockDomainCache,
+		failoverNotificationVersion: 1,
+		config:                      defaultTestConfig(),
+		logger:                      log.NewNoop(),
+		taskLists: map[tasklist.Identifier]tasklist.Manager{
+			*tlIdentifierDecisionGlobal1:       mockTaskListManagerGlobal1,
+			*tlIdentifierActivityGlobal1:       mockTaskListManagerGlobal1,
+			*tlIdentifierDecisionGlobal2:       mockTaskListManagerGlobal2,
+			*tlIdentifierActivityGlobal2:       mockTaskListManagerGlobal2,
+			*tlIdentifierDecisionGlobal3:       mockTaskListManagerGlobal3,
+			*tlIdentifierActivityGlobal3:       mockTaskListManagerGlobal3,
+			*tlIdentifierDecisionLocal1:        mockTaskListManagerLocal1,
+			*tlIdentifierActivityLocal1:        mockTaskListManagerLocal1,
+			*tlIdentifierDecisionActiveActive1: mockTaskListManagerActiveActive1,
+			*tlIdentifierActivityActiveActive1: mockTaskListManagerActiveActive1,
+		},
+	}
+
+	mockTaskListManagerGlobal1.EXPECT().ReleaseBlockedPollers().Times(0)
+	mockTaskListManagerGlobal2.EXPECT().GetTaskListKind().Return(types.TaskListKindNormal).Times(2)
+	mockTaskListManagerGlobal2.EXPECT().DescribeTaskList(gomock.Any()).Return(&types.DescribeTaskListResponse{}).Times(2)
+	mockTaskListManagerGlobal2.EXPECT().ReleaseBlockedPollers().Times(2)
+	mockTaskListManagerGlobal3.EXPECT().GetTaskListKind().Return(types.TaskListKindNormal).Times(2)
+	mockTaskListManagerGlobal3.EXPECT().DescribeTaskList(gomock.Any()).Return(&types.DescribeTaskListResponse{}).Times(2)
+	mockTaskListManagerGlobal3.EXPECT().ReleaseBlockedPollers().Return(errors.New("some-error")).Times(2)
+	mockTaskListManagerLocal1.EXPECT().ReleaseBlockedPollers().Times(0)
+	mockTaskListManagerActiveActive1.EXPECT().ReleaseBlockedPollers().Times(0)
+	mockDomainCache.EXPECT().GetDomainID("global-domain-2").Return("global-domain-2-id", nil).Times(1)
+	mockDomainCache.EXPECT().GetDomainID("global-domain-3").Return("global-domain-3-id", nil).Times(1)
+
+	domains := []*cache.DomainCacheEntry{
+		cache.NewDomainCacheEntryForTest(
+			&persistence.DomainInfo{Name: "global-domain-1", ID: "global-domain-1-id"},
+			nil,
+			true,
+			&persistence.DomainReplicationConfig{ActiveClusterName: clusters[0], Clusters: []*persistence.ClusterReplicationConfig{{ClusterName: "cluster0"}, {ClusterName: "cluster1"}}},
+			0,
+			nil,
+			0,
+			0,
+			6,
+		),
+		cache.NewDomainCacheEntryForTest(
+			&persistence.DomainInfo{Name: "global-domain-2", ID: "global-domain-2-id"},
+			nil,
+			true,
+			&persistence.DomainReplicationConfig{ActiveClusterName: clusters[1], Clusters: []*persistence.ClusterReplicationConfig{{ClusterName: "cluster0"}, {ClusterName: "cluster1"}}},
+			0,
+			nil,
+			4,
+			0,
+			4,
+		),
+		cache.NewDomainCacheEntryForTest(
+			&persistence.DomainInfo{Name: "global-domain-3", ID: "global-domain-3-id"},
+			nil,
+			true,
+			&persistence.DomainReplicationConfig{ActiveClusterName: clusters[1], Clusters: []*persistence.ClusterReplicationConfig{{ClusterName: "cluster0"}, {ClusterName: "cluster1"}}},
+			0,
+			nil,
+			5,
+			0,
+			5,
+		),
+		cache.NewDomainCacheEntryForTest(
+			&persistence.DomainInfo{Name: "local-domain-1", ID: "local-domain-1-id"},
+			nil,
+			false,
+			nil,
+			0,
+			nil,
+			0,
+			0,
+			3,
+		),
+		cache.NewDomainCacheEntryForTest(
+			&persistence.DomainInfo{Name: "active-active-domain-1", ID: "active-active-domain-1-id"},
+			nil,
+			true,
+			&persistence.DomainReplicationConfig{ActiveClusters: &persistence.ActiveClustersConfig{}},
+			0,
+			nil,
+			0,
+			0,
+			2,
+		),
+	}
+
+	engine.domainChangeCallback(domains)
+
+	assert.Equal(t, int64(5), engine.failoverNotificationVersion)
+}
+
+func Test_registerDomainFailoverCallback(t *testing.T) {
+	ctrl := gomock.NewController(t)
+
+	mockDomainCache := cache.NewMockDomainCache(ctrl)
+
+	// Capture the registered catchUpFn
+	var registeredCatchUpFn func(cache.DomainCache, cache.PrepareCallbackFn, cache.CallbackFn)
+	mockDomainCache.EXPECT().RegisterDomainChangeCallback(
+		service.Matching, // id of the callback
+		gomock.Any(),     // catchUpFn
+		gomock.Any(),     // lockTaskProcessingForDomainUpdate
+		gomock.Any(),     // domainChangeCB
+	).Do(func(_ string, catchUpFn, _, _ interface{}) {
+		if fn, ok := catchUpFn.(cache.CatchUpFn); ok {
+			registeredCatchUpFn = fn
+		} else {
+			t.Fatalf("Failed to convert catchUpFn to cache.CatchUpFn: got type %T", catchUpFn)
+		}
+	}).Times(1)
+
+	engine := &matchingEngineImpl{
+		domainCache:                 mockDomainCache,
+		failoverNotificationVersion: 0,
+		config:                      defaultTestConfig(),
+		logger:                      log.NewNoop(),
+		taskLists:                   map[tasklist.Identifier]tasklist.Manager{},
+	}
+
+	engine.registerDomainFailoverCallback()
+
+	t.Run("catchUpFn - No failoverNotificationVersion updates", func(t *testing.T) {
+		mockDomainCache.EXPECT().GetAllDomain().Return(map[string]*cache.DomainCacheEntry{
+			"uuid-domain1": cache.NewDomainCacheEntryForTest(
+				&persistence.DomainInfo{ID: "uuid-domain1", Name: "domain1"},
+				nil,
+				true,
+				&persistence.DomainReplicationConfig{ActiveClusterName: "A"},
+				0,
+				nil,
+				0,
+				0,
+				1,
+			),
+			"uuid-domain2": cache.NewDomainCacheEntryForTest(
+				&persistence.DomainInfo{ID: "uuid-domain2", Name: "domain2"},
+				nil,
+				true,
+				&persistence.DomainReplicationConfig{ActiveClusterName: "A"},
+				0,
+				nil,
+				0,
+				0,
+				4,
+			),
+		})
+
+		prepareCalled := false
+		callbackCalled := false
+		prepare := func() { prepareCalled = true }
+		callback := func([]*cache.DomainCacheEntry) { callbackCalled = true }
+
+		if registeredCatchUpFn != nil {
+			registeredCatchUpFn(mockDomainCache, prepare, callback)
+			assert.False(t, prepareCalled, "prepareCallback should not be called")
+			assert.False(t, callbackCalled, "callback should not be called")
+		} else {
+			assert.Fail(t, "catchUpFn was not registered")
+		}
+
+		assert.Equal(t, int64(0), engine.failoverNotificationVersion)
+	})
+
+	t.Run("catchUpFn - failoverNotificationVersion updated", func(t *testing.T) {
+		mockDomainCache.EXPECT().GetAllDomain().Return(map[string]*cache.DomainCacheEntry{
+			"uuid-domain1": cache.NewDomainCacheEntryForTest(
+				&persistence.DomainInfo{ID: "uuid-domain1", Name: "domain1"},
+				nil,
+				true,
+				&persistence.DomainReplicationConfig{ActiveClusterName: "A"},
+				0,
+				nil,
+				3,
+				0,
+				3,
+			),
+			"uuid-domain2": cache.NewDomainCacheEntryForTest(
+				&persistence.DomainInfo{ID: "uuid-domain2", Name: "domain2"},
+				nil,
+				true,
+				&persistence.DomainReplicationConfig{ActiveClusterName: "A"},
+				0,
+				nil,
+				2,
+				0,
+				4,
+			),
+		})
+
+		prepareCalled := false
+		callbackCalled := false
+		prepare := func() { prepareCalled = true }
+		callback := func([]*cache.DomainCacheEntry) { callbackCalled = true }
+
+		if registeredCatchUpFn != nil {
+			registeredCatchUpFn(mockDomainCache, prepare, callback)
+			assert.False(t, prepareCalled, "prepareCallback should not be called")
+			assert.False(t, callbackCalled, "callback should not be called")
+		} else {
+			assert.Fail(t, "catchUpFn was not registered")
+		}
+
+		assert.Equal(t, int64(3), engine.failoverNotificationVersion)
+	})
+
+}
diff --git a/service/matching/handler/handler.go b/service/matching/handler/handler.go
index 082503c7763..9f6d76cb538 100644
--- a/service/matching/handler/handler.go
+++ b/service/matching/handler/handler.go
@@ -89,6 +89,7 @@ func NewHandler(
 
 // Start starts the handler
 func (h *handlerImpl) Start() {
+	h.engine.Start()
 	h.startWG.Done()
 }
 
diff --git a/service/matching/handler/handler_test.go b/service/matching/handler/handler_test.go
index aaebd9b8eb1..f2f3112f2c4 100644
--- a/service/matching/handler/handler_test.go
+++ b/service/matching/handler/handler_test.go
@@ -123,6 +123,8 @@ func (s *handlerSuite) TestStart() {
 	cfg := config.NewConfig(dynamicconfig.NewCollection(dynamicconfig.NewInMemoryClient(), s.mockResource.Logger), "matching-test", getIsolationGroupsHelper)
"matching-test", getIsolationGroupsHelper) handler := s.getHandler(cfg) + s.mockEngine.EXPECT().Start().Times(1) + handler.Start() } @@ -132,6 +134,7 @@ func (s *handlerSuite) TestStop() { cfg := config.NewConfig(dynamicconfig.NewCollection(dynamicconfig.NewInMemoryClient(), s.mockResource.Logger), "matching-test", getIsolationGroupsHelper) handler := s.getHandler(cfg) + s.mockEngine.EXPECT().Start().Times(1) s.mockEngine.EXPECT().Stop().Times(1) handler.Start() diff --git a/service/matching/handler/membership_test.go b/service/matching/handler/membership_test.go index 21720f8d8cc..84c926b8c28 100644 --- a/service/matching/handler/membership_test.go +++ b/service/matching/handler/membership_test.go @@ -207,14 +207,17 @@ func TestSubscriptionAndShutdown(t *testing.T) { shutdownWG := &sync.WaitGroup{} shutdownWG.Add(1) + mockDomainCache := cache.NewMockDomainCache(ctrl) + e := matchingEngineImpl{ shutdownCompletion: shutdownWG, membershipResolver: m, config: &config.Config{ EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }, }, - shutdown: make(chan struct{}), - logger: log.NewNoop(), + shutdown: make(chan struct{}), + logger: log.NewNoop(), + domainCache: mockDomainCache, } // anytimes here because this is quite a racy test and the actual assertions for the unsubscription logic will be separated out @@ -228,6 +231,7 @@ func TestSubscriptionAndShutdown(t *testing.T) { } inc <- &m }) + mockDomainCache.EXPECT().UnregisterDomainChangeCallback(service.Matching).Times(1) go func() { // then call stop so the test can finish @@ -242,6 +246,8 @@ func TestSubscriptionAndErrorReturned(t *testing.T) { ctrl := gomock.NewController(t) m := membership.NewMockResolver(ctrl) + mockDomainCache := cache.NewMockDomainCache(ctrl) + shutdownWG := sync.WaitGroup{} shutdownWG.Add(1) @@ -251,8 +257,9 @@ func TestSubscriptionAndErrorReturned(t *testing.T) { config: &config.Config{ EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }, }, - shutdown: make(chan struct{}), - logger: log.NewNoop(), + shutdown: make(chan struct{}), + logger: log.NewNoop(), + domainCache: mockDomainCache, } // this should trigger the error case on a membership event @@ -268,6 +275,8 @@ func TestSubscriptionAndErrorReturned(t *testing.T) { inc <- &m }) + mockDomainCache.EXPECT().UnregisterDomainChangeCallback(service.Matching).Times(1) + go func() { // then call stop so the test can finish time.Sleep(time.Second) @@ -281,6 +290,8 @@ func TestSubscribeToMembershipChangesQuitsIfSubscribeFails(t *testing.T) { ctrl := gomock.NewController(t) m := membership.NewMockResolver(ctrl) + mockDomainCache := cache.NewMockDomainCache(ctrl) + logger, logs := testlogger.NewObserved(t) shutdownWG := sync.WaitGroup{} @@ -292,8 +303,9 @@ func TestSubscribeToMembershipChangesQuitsIfSubscribeFails(t *testing.T) { config: &config.Config{ EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true }, }, - shutdown: make(chan struct{}), - logger: logger, + shutdown: make(chan struct{}), + logger: logger, + domainCache: mockDomainCache, } // this should trigger the error case on a membership event @@ -302,6 +314,8 @@ func TestSubscribeToMembershipChangesQuitsIfSubscribeFails(t *testing.T) { m.EXPECT().Subscribe(service.Matching, "matching-engine", gomock.Any()). 
 		Return(errors.New("matching-engine is already subscribed to updates"))
 
+	mockDomainCache.EXPECT().UnregisterDomainChangeCallback(service.Matching).AnyTimes()
+
 	go func() {
 		// then call stop so the test can finish
 		time.Sleep(time.Second)
@@ -324,9 +338,12 @@ func TestGetTasklistManagerShutdownScenario(t *testing.T) {
 	ctrl := gomock.NewController(t)
 	m := membership.NewMockResolver(ctrl)
 
+	mockDomainCache := cache.NewMockDomainCache(ctrl)
+
 	self := membership.NewDetailedHostInfo("self", "self", nil)
 	m.EXPECT().WhoAmI().Return(self, nil).AnyTimes()
 
+	mockDomainCache.EXPECT().UnregisterDomainChangeCallback(service.Matching).Times(1)
 	shutdownWG := sync.WaitGroup{}
 	shutdownWG.Add(0)
@@ -337,8 +354,9 @@ func TestGetTasklistManagerShutdownScenario(t *testing.T) {
 		config: &config.Config{
 			EnableTasklistOwnershipGuard: func(opts ...dynamicproperties.FilterOption) bool { return true },
 		},
-		shutdown: make(chan struct{}),
-		logger:   log.NewNoop(),
+		shutdown:    make(chan struct{}),
+		logger:      log.NewNoop(),
+		domainCache: mockDomainCache,
 	}
 
 	// set this engine to be shutting down so as to trigger the tasklistGetTasklistByID guard
diff --git a/service/matching/tasklist/interfaces.go b/service/matching/tasklist/interfaces.go
index f7dec24a953..ad65e98ac44 100644
--- a/service/matching/tasklist/interfaces.go
+++ b/service/matching/tasklist/interfaces.go
@@ -64,6 +64,7 @@ type (
 		UpdateTaskListPartitionConfig(context.Context, *types.TaskListPartitionConfig) error
 		RefreshTaskListPartitionConfig(context.Context, *types.TaskListPartitionConfig) error
 		LoadBalancerHints() *types.LoadBalancerHints
+		ReleaseBlockedPollers() error
 	}
 
 	TaskMatcher interface {
@@ -76,6 +77,7 @@ type (
 		PollForQuery(ctx context.Context) (*InternalTask, error)
 		UpdateRatelimit(rps *float64)
 		Rate() float64
+		RefreshCancelContext()
 	}
 
 	Forwarder interface {
diff --git a/service/matching/tasklist/interfaces_mock.go b/service/matching/tasklist/interfaces_mock.go
index 67c14650bb3..e3a769469d6 100644
--- a/service/matching/tasklist/interfaces_mock.go
+++ b/service/matching/tasklist/interfaces_mock.go
@@ -198,6 +198,20 @@ func (mr *MockManagerMockRecorder) RefreshTaskListPartitionConfig(arg0, arg1 any
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshTaskListPartitionConfig", reflect.TypeOf((*MockManager)(nil).RefreshTaskListPartitionConfig), arg0, arg1)
 }
 
+// ReleaseBlockedPollers mocks base method.
+func (m *MockManager) ReleaseBlockedPollers() error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "ReleaseBlockedPollers")
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// ReleaseBlockedPollers indicates an expected call of ReleaseBlockedPollers.
+func (mr *MockManagerMockRecorder) ReleaseBlockedPollers() *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReleaseBlockedPollers", reflect.TypeOf((*MockManager)(nil).ReleaseBlockedPollers))
+}
+
 // Start mocks base method.
 func (m *MockManager) Start() error {
 	m.ctrl.T.Helper()
@@ -419,6 +433,18 @@ func (mr *MockTaskMatcherMockRecorder) Rate() *gomock.Call {
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rate", reflect.TypeOf((*MockTaskMatcher)(nil).Rate))
 }
 
+// RefreshCancelContext mocks base method.
+func (m *MockTaskMatcher) RefreshCancelContext() {
+	m.ctrl.T.Helper()
+	m.ctrl.Call(m, "RefreshCancelContext")
+}
+
+// RefreshCancelContext indicates an expected call of RefreshCancelContext.
+func (mr *MockTaskMatcherMockRecorder) RefreshCancelContext() *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshCancelContext", reflect.TypeOf((*MockTaskMatcher)(nil).RefreshCancelContext))
+}
+
 // UpdateRatelimit mocks base method.
 func (m *MockTaskMatcher) UpdateRatelimit(rps *float64) {
 	m.ctrl.T.Helper()
diff --git a/service/matching/tasklist/matcher.go b/service/matching/tasklist/matcher.go
index ad50c1825f7..f98305b23f6 100644
--- a/service/matching/tasklist/matcher.go
+++ b/service/matching/tasklist/matcher.go
@@ -24,6 +24,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"sync"
 	"time"
 
 	"go.uber.org/atomic"
@@ -67,6 +68,7 @@ type taskMatcherImpl struct {
 
 	cancelCtx  context.Context // used to cancel long polling
 	cancelFunc context.CancelFunc
+	cancelLock sync.Mutex
 
 	tasklist     *Identifier
 	tasklistKind types.TaskListKind
@@ -464,10 +466,14 @@ func (tm *taskMatcherImpl) PollForQuery(ctx context.Context) (*InternalTask, err
 		tm.scope.RecordTimer(metrics.PollLocalMatchLatencyPerTaskList, time.Since(startT))
 		return task, nil
 	}
+
+	ctxWithCancelPropagation, stopFn := ctxutils.WithPropagatedContextCancel(ctx, tm.cancelCtx)
+	defer stopFn()
+
 	// there is no local poller available to pickup this task. Now block waiting
 	// either for a local poller or a forwarding token to be available. When a
 	// forwarding token becomes available, send this poll to a parent partition
-	return tm.pollOrForward(ctx, startT, "", nil, nil, tm.queryTaskC)
+	return tm.pollOrForward(ctxWithCancelPropagation, startT, "", nil, nil, tm.queryTaskC)
 }
 
 // UpdateRatelimit updates the task dispatch rate
@@ -488,6 +494,12 @@ func (tm *taskMatcherImpl) Rate() float64 {
 	return rate
 }
 
+func (tm *taskMatcherImpl) RefreshCancelContext() {
+	tm.cancelLock.Lock()
+	defer tm.cancelLock.Unlock()
+	tm.cancelCtx, tm.cancelFunc = context.WithCancel(context.Background())
+}
+
 func (tm *taskMatcherImpl) pollOrForward(
 	ctx context.Context,
 	startT time.Time,
diff --git a/service/matching/tasklist/matcher_test.go b/service/matching/tasklist/matcher_test.go
index 9f0ab19d696..2d2d9e91a01 100644
--- a/service/matching/tasklist/matcher_test.go
+++ b/service/matching/tasklist/matcher_test.go
@@ -568,12 +568,28 @@ func (t *MatcherTestSuite) TestPollersDisconnectedAfterDisconnectBlockedPollers(
 	t.disableRemoteForwarding()
 	t.matcher.DisconnectBlockedPollers()
 
-	longPollingCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	longPollingCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
 	defer cancel()
 
 	task, err := t.matcher.Poll(longPollingCtx, "")
 	t.ErrorIs(err, ErrNoTasks, "closed matcher should result in no tasks")
 	t.Nil(task)
+	t.NoError(longPollingCtx.Err(), "the parent context should not be cancelled; only the matcher's cancel context was")
+}
+
+func (t *MatcherTestSuite) TestPollersConnectedAfterDisconnectBlockedPollersSetCancelContext() {
+	defer goleak.VerifyNone(t.T())
+	t.disableRemoteForwarding()
+	t.matcher.DisconnectBlockedPollers()
+	t.matcher.RefreshCancelContext()
+
+	longPollingCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+	defer cancel()
+
+	task, err := t.matcher.Poll(longPollingCtx, "")
+	t.ErrorIs(err, ErrNoTasks, "no tasks to be matched to poller")
+	t.Nil(task)
+	t.ErrorIs(longPollingCtx.Err(), context.DeadlineExceeded, "the child context was not cancelled; the parent context timed out")
 }
 
 func (t *MatcherTestSuite) TestRemotePoll() {
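Note (illustration only, not part of the patch): the PollForQuery change above relies on ctxutils.WithPropagatedContextCancel to derive a poll context that ends when either the caller's context or the matcher's cancelCtx finishes; that is what lets DisconnectBlockedPollers unblock long polls without cancelling the caller. A rough Go sketch of the idea, under the assumption that the real helper also returns a stop function to detach the watcher (this is not Cadence's actual implementation):

package ctxsketch

import "context"

// withPropagatedCancel returns a child of parent that is additionally
// cancelled when other is done. The returned cancel doubles as the stop
// function: calling it detaches the watcher goroutine.
func withPropagatedCancel(parent, other context.Context) (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithCancel(parent)
	go func() {
		select {
		case <-other.Done():
			cancel() // matcher-side cancel fired: unblock the poll
		case <-ctx.Done():
			// poll finished, caller cancelled, or stop called: stop watching
		}
	}()
	return ctx, cancel
}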
diff --git a/service/matching/tasklist/task_list_manager.go b/service/matching/tasklist/task_list_manager.go
index 2bea1116278..bd72838c9b5 100644
--- a/service/matching/tasklist/task_list_manager.go
+++ b/service/matching/tasklist/task_list_manager.go
@@ -119,6 +119,7 @@ type (
 		startWG       sync.WaitGroup // ensures that background processes do not start until setup is ready
 		stopWG        sync.WaitGroup
 		stopped       int32
+		stoppedLock   sync.RWMutex
 		closeCallback func(Manager)
 
 		throttleRetry *backoff.ThrottleRetry
@@ -302,6 +303,8 @@ func (c *taskListManagerImpl) Start() error {
 
 // Stop stops task list manager and calls Stop on all background child objects
 func (c *taskListManagerImpl) Stop() {
+	c.stoppedLock.Lock()
+	defer c.stoppedLock.Unlock()
 	if !atomic.CompareAndSwapInt32(&c.stopped, 0, 1) {
 		return
 	}
@@ -757,6 +760,21 @@ func (c *taskListManagerImpl) DescribeTaskList(includeTaskListStatus bool) *type
 	return response
 }
 
+func (c *taskListManagerImpl) ReleaseBlockedPollers() error {
+	c.stoppedLock.RLock()
+	defer c.stoppedLock.RUnlock()
+
+	if atomic.LoadInt32(&c.stopped) == 1 {
+		c.logger.Info("Task list manager is already stopped")
+		return errShutdown
+	}
+
+	c.matcher.DisconnectBlockedPollers()
+	c.matcher.RefreshCancelContext()
+
+	return nil
+}
+
 func (c *taskListManagerImpl) String() string {
 	buf := new(bytes.Buffer)
 	if c.taskListID.GetType() == persistence.TaskListTypeActivity {
diff --git a/service/matching/tasklist/task_list_manager_test.go b/service/matching/tasklist/task_list_manager_test.go
index ca5549420fa..0ff6c01172e 100644
--- a/service/matching/tasklist/task_list_manager_test.go
+++ b/service/matching/tasklist/task_list_manager_test.go
@@ -1822,6 +1822,53 @@ func TestGetNumPartitions(t *testing.T) {
 	assert.NotPanics(t, func() { tlm.matcher.UpdateRatelimit(common.Ptr(float64(100))) })
 }
 
+func TestDisconnectBlockedPollers(t *testing.T) {
+	tests := []struct {
+		name                 string
+		newActiveClusterName *string
+		mockSetup            func(mockMatcher *MockTaskMatcher)
+		stopped              int32
+		expectedErr          error
+	}{
+		{
+			name:                 "disconnect blocked pollers and refresh cancel context",
+			newActiveClusterName: common.StringPtr("new-active-cluster"),
+			mockSetup: func(mockMatcher *MockTaskMatcher) {
+				mockMatcher.EXPECT().DisconnectBlockedPollers().Times(1)
+				mockMatcher.EXPECT().RefreshCancelContext().Times(1)
+			},
+			expectedErr: nil,
+		},
+		{
+			name:                 "tasklist manager is shutting down, noop",
+			newActiveClusterName: common.StringPtr("new-active-cluster"),
+			mockSetup:            func(mockMatcher *MockTaskMatcher) {},
+			stopped:              int32(1),
+			expectedErr:          errShutdown,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			tlID, err := NewIdentifier("domain-id", "tl", persistence.TaskListTypeDecision)
+			require.NoError(t, err)
+
+			tlm, _ := setupMocksForTaskListManager(t, tlID, types.TaskListKindNormal)
+
+			mockMatcher := NewMockTaskMatcher(gomock.NewController(t))
+			tlm.matcher = mockMatcher
+
+			tc.mockSetup(mockMatcher)
+
+			tlm.stopped = tc.stopped
+
+			err = tlm.ReleaseBlockedPollers()
+
+			assert.Equal(t, tc.expectedErr, err)
+		})
+	}
+}
+
 func partitions(num int) map[int]*types.TaskListPartition {
 	result := make(map[int]*types.TaskListPartition, num)
 	for i := 0; i < num; i++ {
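Note (illustration only, not part of the patch): in the task_list_manager.go change above, Stop takes stoppedLock as a writer while ReleaseBlockedPollers takes it as a reader, so a shutdown cannot interleave with an in-flight disconnect, while concurrent failover disconnects do not serialize against each other; the atomic stopped flag still encodes the state itself. A compressed Go sketch of that locking split (names are local to the sketch):

package locksketch

import (
	"errors"
	"sync"
	"sync/atomic"
)

var errShutdown = errors.New("shutting down")

type manager struct {
	stoppedLock sync.RWMutex
	stopped     int32
}

// Stop is the writer side: it excludes any concurrent ReleaseBlockedPollers.
func (m *manager) Stop() {
	m.stoppedLock.Lock()
	defer m.stoppedLock.Unlock()
	atomic.CompareAndSwapInt32(&m.stopped, 0, 1)
}

// ReleaseBlockedPollers is the reader side: many failover disconnects may
// run concurrently, but never while Stop holds the write lock.
func (m *manager) ReleaseBlockedPollers() error {
	m.stoppedLock.RLock()
	defer m.stoppedLock.RUnlock()
	if atomic.LoadInt32(&m.stopped) == 1 {
		return errShutdown
	}
	// disconnect blocked pollers and refresh the cancel context here
	return nil
}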
diff --git a/simulation/replication/testdata/replication_simulation_default.yaml b/simulation/replication/testdata/replication_simulation_default.yaml
index 658b1140d41..59c2284bcab 100644
--- a/simulation/replication/testdata/replication_simulation_default.yaml
+++ b/simulation/replication/testdata/replication_simulation_default.yaml
@@ -29,7 +29,7 @@ operations:
 #    failoverTimeoutSec: 5 # unset means force failover. setting it means graceful failover request
 
   - op: validate
-    at: 120s # todo: this should work at 40s mark
+    at: 40s
     workflowID: wf1
    cluster: cluster1
     want: