diff --git a/service/worker/asyncworkflow/async_workflow_consumer_manager.go b/service/worker/asyncworkflow/async_workflow_consumer_manager.go index 1e87da24366..f50becee8fd 100644 --- a/service/worker/asyncworkflow/async_workflow_consumer_manager.go +++ b/service/worker/asyncworkflow/async_workflow_consumer_manager.go @@ -71,6 +71,10 @@ func WithEmitConsumerCountMetrifFn(fn func(int)) ConsumerManagerOptions { } } +func withAfterIterFn(fn func()) ConsumerManagerOptions { + return func(c *ConsumerManager) { c.afterIterFn = fn } +} + func NewConsumerManager( logger log.Logger, metricsClient metrics.Client, @@ -119,6 +123,7 @@ type ConsumerManager struct { wg sync.WaitGroup activeConsumers map[string]provider.Consumer emitConsumerCountMetricFn func(int) + afterIterFn func() // test hook: called after each ticker iteration, nil in production } func (c *ConsumerManager) Start() { @@ -176,6 +181,9 @@ func (c *ConsumerManager) run() { c.logger.Info("ConsumerManager background loop stopped because context is done") return } + if c.afterIterFn != nil { + c.afterIterFn() + } } } diff --git a/service/worker/asyncworkflow/async_workflow_consumer_manager_test.go b/service/worker/asyncworkflow/async_workflow_consumer_manager_test.go index 93cb3940a7f..dd735a2e65c 100644 --- a/service/worker/asyncworkflow/async_workflow_consumer_manager_test.go +++ b/service/worker/asyncworkflow/async_workflow_consumer_manager_test.go @@ -349,6 +349,8 @@ func TestConsumerManagerEnabledDisabled(t *testing.T) { var consumerMgrEnabled, consumerCount int32 + refreshed := make(chan struct{}, 1) + // create consumer manager cm := NewConsumerManager( testlogger.New(t), @@ -363,6 +365,12 @@ func TestConsumerManagerEnabledDisabled(t *testing.T) { WithEmitConsumerCountMetrifFn(func(count int) { atomic.StoreInt32(&consumerCount, int32(count)) }), + withAfterIterFn(func() { + select { + case refreshed <- struct{}{}: + default: + } + }), ) cm.Start() @@ -370,8 +378,9 @@ func TestConsumerManagerEnabledDisabled(t *testing.T) { // wait for the first round of consumers to be created and verify consumer count atomic.StoreInt32(&consumerMgrEnabled, 1) + mockTimeSrc.BlockUntil(1) // wait for run() goroutine to create the ticker mockTimeSrc.Advance(defaultRefreshInterval) - time.Sleep(50 * time.Millisecond) + <-refreshed t.Log("first round comparison") got := atomic.LoadInt32(&consumerCount) want := 1 // consumer manager is enabled @@ -382,7 +391,7 @@ func TestConsumerManagerEnabledDisabled(t *testing.T) { // disable consumer manager and wait for the second round of refresh atomic.StoreInt32(&consumerMgrEnabled, 0) mockTimeSrc.Advance(defaultRefreshInterval) - time.Sleep(50 * time.Millisecond) + <-refreshed got = atomic.LoadInt32(&consumerCount) want = 0 // all consumers should be stopped when consumer manager is disabled if got != int32(want) { @@ -392,7 +401,7 @@ func TestConsumerManagerEnabledDisabled(t *testing.T) { // enable consumer manager and wait for the third round of refresh atomic.StoreInt32(&consumerMgrEnabled, 1) mockTimeSrc.Advance(defaultRefreshInterval) - time.Sleep(50 * time.Millisecond) + <-refreshed got = atomic.LoadInt32(&consumerCount) want = 1 // consumer manager is enabled if got != int32(want) {