diff --git a/internal/internal_task_pollers.go b/internal/internal_task_pollers.go
index 72a067926..a6ad37bb9 100644
--- a/internal/internal_task_pollers.go
+++ b/internal/internal_task_pollers.go
@@ -315,24 +315,12 @@ func (bp *basePoller) doPoll(pollFunc func(ctx context.Context) (taskForWorker,
 	}()
 
 	if bp.workerPollCompleteOnShutdown != nil && bp.workerPollCompleteOnShutdown.Load() {
-		// Don't kill the gRPC stream. After ShutdownWorker, the server returns empty responses.
-		select {
-		case <-doneC:
-			return result, err
-		case <-bp.stopC:
-			// TEMP FIX: Give the server a reasonable window to complete the poll after
-			// ShutdownWorker. Fall back to cancelling the poll if it takes too
-			// long, e.g. when the gRPC connection was closed before Stop().
-			timer := time.NewTimer(5 * time.Second)
-			defer timer.Stop()
-			select {
-			case <-doneC:
-			case <-timer.C:
-				cancel()
-				<-doneC
-			}
-			return result, err
-		}
+		// Don't cancel the gRPC stream. After ShutdownWorker, the server
+		// completes the poll with an empty response. The poll is bounded
+		// by the gRPC timeout (pollTaskServiceTimeOut). Stop() waits for
+		// all pollers to finish before proceeding to task drain.
+		<-doneC
+		return result, err
 	}
 
 	// Legacy: cancel in-flight polls immediately on shutdown
diff --git a/internal/internal_worker_base.go b/internal/internal_worker_base.go
index 28254763d..cf8b921cf 100644
--- a/internal/internal_worker_base.go
+++ b/internal/internal_worker_base.go
@@ -216,6 +216,7 @@ type (
 		lastPollTaskErrLock sync.Mutex
 
 		noRepoll atomic.Bool
+		pollerWG sync.WaitGroup
 	}
 
 	eagerOrPolledTask interface {
@@ -391,6 +392,7 @@ func (bw *baseWorker) Start() {
 
 		for i := 0; i < taskWorker.pollerCount; i++ {
 			bw.stopWG.Add(1)
+			bw.pollerWG.Add(1)
 			go bw.runPoller(taskWorker)
 		}
 
@@ -403,6 +405,15 @@
 		}
 	}
 
+	// When all pollers have exited, close taskQueueCh so the dispatcher
+	// knows no more polled tasks will arrive and can drain what remains.
+	bw.stopWG.Add(1)
+	go func() {
+		defer bw.stopWG.Done()
+		bw.pollerWG.Wait()
+		close(bw.taskQueueCh)
+	}()
+
 	bw.stopWG.Add(1)
 	go bw.runTaskDispatcher()
 
@@ -428,6 +439,7 @@ func (bw *baseWorker) isStop() bool {
 
 func (bw *baseWorker) runPoller(taskWorker scalableTaskPoller) {
 	defer bw.stopWG.Done()
+	defer bw.pollerWG.Done()
 
 	// Note: With poller autoscaling, this metric doesn't make a lot of sense since the number of pollers can go up and down.
 	bw.metricsHandler.Counter(metrics.PollerStartCounter).Inc(1)
@@ -561,24 +573,17 @@ func (bw *baseWorker) processTaskAsync(eagerOrPolled eagerOrPolledTask) {
 
 func (bw *baseWorker) runTaskDispatcher() {
 	defer bw.stopWG.Done()
 
-	for {
-		// wait for new task or worker stop
-		select {
-		case <-bw.stopCh:
-			// Currently we can drop any tasks received when closing.
-			// https://github.com/temporalio/sdk-go/issues/1197
-			return
-		case task := <-bw.taskQueueCh:
-			// for non-polled-task (local activity result as task or eager task), we don't need to rate limit
-			_, isPolledTask := task.(*polledTask)
-			if isPolledTask && bw.taskLimiter.Wait(bw.limiterContext) != nil {
-				if bw.isStop() {
-					bw.releaseSlot(task.getPermit(), SlotReleaseReasonUnused)
-					return
-				}
-			}
-			bw.processTaskAsync(task)
+	for task := range bw.taskQueueCh {
+		// For non-polled tasks (a local activity result as task, or an
+		// eager task), we don't need to rate limit. During shutdown the
+		// limiter context is cancelled, so Wait returns immediately; we
+		// still process the task rather than dropping it.
+		if _, isPolledTask := task.(*polledTask); isPolledTask {
+			// Ignore the error: during shutdown the limiter context is
+			// cancelled, but we still process remaining tasks.
+			_ = bw.taskLimiter.Wait(bw.limiterContext)
 		}
+		bw.processTaskAsync(task)
 	}
 }
@@ -639,11 +644,10 @@ func (bw *baseWorker) pollTask(taskWorker scalableTaskPoller, slotPermit *SlotPe
 			taskWorker.pollerAutoscalerReportHandle.handleTask(task)
 		}
 
-		select {
-		case bw.taskQueueCh <- &polledTask{task: task, permit: slotPermit}:
-			didSendTask = true
-		case <-bw.stopCh:
-		}
+		// The dispatcher is guaranteed to be alive: it only exits after
+		// taskQueueCh is closed, which happens after all pollers finish.
+		bw.taskQueueCh <- &polledTask{task: task, permit: slotPermit}
+		didSendTask = true
 	}
 }
 
@@ -703,6 +707,12 @@ func (bw *baseWorker) Stop() {
 	close(bw.stopCh)
 	bw.limiterContextCancel()
 
+	// Wait for pollers to finish. The gRPC poll timeout (pollTaskServiceTimeOut) bounds this if the connection is broken.
+	bw.pollerWG.Wait()
+
+	// Wait for task processing to complete. The dispatcher drains
+	// taskQueueCh (closed once the pollers finish, above), and the
+	// processTaskAsync goroutines are tracked in stopWG.
 	if success := awaitWaitGroup(&bw.stopWG, bw.options.stopTimeout); !success {
 		traceLog(func() {
 			bw.logger.Info("Worker graceful stop timed out.", "Stop timeout", bw.options.stopTimeout)
diff --git a/internal/internal_worker_base_test.go b/internal/internal_worker_base_test.go
index 952df45eb..8cb9cf666 100644
--- a/internal/internal_worker_base_test.go
+++ b/internal/internal_worker_base_test.go
@@ -330,6 +330,104 @@ type noopTaskProcessor struct{}
 
 func (noopTaskProcessor) ProcessTask(any) error { return nil }
 
+// TestTaskNotDroppedDuringShutdown verifies the two-stage shutdown: when a
+// poller receives a task during shutdown, the task is still dispatched and
+// processed rather than silently dropped.
+func TestTaskNotDroppedDuringShutdown(t *testing.T) {
+	taskProcessed := make(chan struct{})
+	pollStarted := make(chan struct{})
+
+	// A poller that blocks until returnTask is closed, then returns a task
+	// exactly once. Subsequent polls return nil so the poller can exit.
+	tp := &shutdownTaskPoller{
+		pollStarted: pollStarted,
+		returnTask:  make(chan struct{}),
+		task:        &testTask{},
+	}
+
+	processor := &recordingTaskProcessor{
+		processed: taskProcessed,
+	}
+
+	bw := newBaseWorker(baseWorkerOptions{
+		slotSupplier:     &testSlotSupplier{},
+		maxTaskPerSecond: 1000,
+		taskPollers: []scalableTaskPoller{
+			{taskPollerType: "test", pollerCount: 1, taskPoller: tp},
+		},
+		taskProcessor:  processor,
+		workerType:     "ShutdownTest",
+		logger:         ilog.NewNopLogger(),
+		stopTimeout:    5 * time.Second,
+		metricsHandler: metrics.NopHandler,
+	})
+
+	bw.Start()
+
+	// Wait for the poller to be actively polling.
+	<-pollStarted
+
+	// Release the poller so it returns a task, then stop the worker.
+	// The poller returns a task and then nil on subsequent polls,
+	// allowing it to exit via noRepoll/stopCh during Stop().
+	close(tp.returnTask)
+
+	// Stop exercises the real shutdown path: noRepoll, close(stopCh),
+	// limiterContextCancel, and awaitWaitGroup.
+	stopDone := make(chan struct{})
+	go func() {
+		bw.Stop()
+		close(stopDone)
+	}()
+
+	select {
+	case <-taskProcessed:
+		// Success: the task was dispatched and processed during shutdown.
+	case <-time.After(5 * time.Second):
+		t.Fatal("task polled during shutdown was not processed (dropped)")
+	}
+
+	select {
+	case <-stopDone:
+		// Stop completed cleanly.
+	case <-time.After(5 * time.Second):
+		t.Fatal("Stop() did not return in time")
+	}
+}
+
+// shutdownTaskPoller blocks until returnTask is closed, then returns a task
+// exactly once. Subsequent polls return nil.
+type shutdownTaskPoller struct {
+	pollStarted chan struct{}
+	returnTask  chan struct{}
+	task        taskForWorker
+	returned    atomic.Bool
+}
+
+func (p *shutdownTaskPoller) PollTask() (taskForWorker, error) {
+	select {
+	case p.pollStarted <- struct{}{}:
+	default:
+	}
+	<-p.returnTask
+	if p.returned.CompareAndSwap(false, true) {
+		return p.task, nil
+	}
+	return nil, nil
+}
+
+type recordingTaskProcessor struct {
+	processed chan struct{}
+}
+
+func (p *recordingTaskProcessor) ProcessTask(any) error {
+	select {
+	case p.processed <- struct{}{}:
+	default:
+	}
+	return nil
+}
+
 func (s *PollScalerReportHandleSuite) TestAutoscaleDownOnTimeoutWithCapability() {
 	targetSuggestion := 0
 	ps := newPollScalerReportHandle(pollScalerReportHandleOptions{
diff --git a/test/integration_test.go b/test/integration_test.go
index 67971e1d4..c641cd89c 100644
--- a/test/integration_test.go
+++ b/test/integration_test.go
@@ -4316,7 +4316,9 @@ func (ts *IntegrationTestSuite) testUpdateOrderingCancel(cancelWf bool) {
 		}()
 	}
 
-	// The server does not support admitted updates, so we send the update in a separate goroutine
+	// The server does not support admitted updates, so we send the update in a separate goroutine.
+	// Keep this shorter than the activity's ScheduleToCloseTimeout (5s) so the new worker
+	// has time to execute activities before they time out.
 	time.Sleep(5 * time.Second)
 	// Now create a new worker on that same task queue to resume the work of the
 	// workflow
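
The core of this change is an ordering guarantee: the unguarded send on `taskQueueCh` in `pollTask` is safe only because the channel is closed strictly after every poller has returned, and the dispatcher drains it by ranging until that close. The sketch below isolates that two-stage shutdown pattern in a standalone program. It is an illustration only, not the SDK's code: `worker`, `taskCh`, `runPoller`, `runDispatcher`, and the `int` task type are hypothetical stand-ins for `baseWorker` and its machinery.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// worker is a hypothetical stand-in for baseWorker, reduced to the
// pieces that matter for shutdown ordering.
type worker struct {
	stopCh   chan struct{}
	taskCh   chan int
	pollerWG sync.WaitGroup // pollers only (cf. bw.pollerWG)
	allWG    sync.WaitGroup // pollers + closer + dispatcher (cf. bw.stopWG)
}

func newWorker(pollers int) *worker {
	w := &worker{
		stopCh: make(chan struct{}),
		taskCh: make(chan int, 16),
	}
	for i := 0; i < pollers; i++ {
		w.pollerWG.Add(1)
		w.allWG.Add(1)
		go w.runPoller(i)
	}
	// Close taskCh only after every poller has exited, so the dispatcher's
	// range loop ends exactly when no more sends can happen.
	w.allWG.Add(1)
	go func() {
		defer w.allWG.Done()
		w.pollerWG.Wait()
		close(w.taskCh)
	}()
	w.allWG.Add(1)
	go w.runDispatcher()
	return w
}

func (w *worker) runPoller(id int) {
	defer w.allWG.Done()
	defer w.pollerWG.Done()
	for n := 0; ; n++ {
		select {
		case <-w.stopCh:
			return
		default:
		}
		time.Sleep(10 * time.Millisecond) // stand-in for a long-poll RPC
		// Unguarded send, mirroring pollTask: safe because taskCh is
		// closed only after this and every other poller has returned.
		// A poll that completes during shutdown still delivers its task.
		w.taskCh <- id*1000 + n
	}
}

func (w *worker) runDispatcher() {
	defer w.allWG.Done()
	// Ranging over taskCh (instead of selecting on stopCh) means tasks
	// received during shutdown are processed, not dropped.
	for task := range w.taskCh {
		fmt.Println("processed task", task)
	}
}

func (w *worker) Stop() {
	close(w.stopCh)   // stage 1: ask pollers to stop
	w.pollerWG.Wait() // in-flight polls finish (or hit their own timeout)
	w.allWG.Wait()    // stage 2: dispatcher drains taskCh and exits
}

func main() {
	w := newWorker(2)
	time.Sleep(50 * time.Millisecond)
	w.Stop()
	fmt.Println("shut down without dropping tasks")
}
```

The same reasoning explains why the dispatcher's `select` on `stopCh` could be removed: once closing the channel is tied to poller completion, channel closure is the only shutdown signal the dispatcher needs, and the dropped-task window described in https://github.com/temporalio/sdk-go/issues/1197 disappears.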