Skip to content

Commit d635ac7

Browse files
Make IsSticky() accurate for mixed-mode workflow pollers
Move the sticky/normal poll decision before slot reservation so that SlotSupplier.ReserveSlot receives an accurate IsSticky() value even when the worker uses a mixed-mode poller (the default for non-autoscaling workers).

Key changes:
- Add pollHint struct to carry pre-decided sticky/normal choice
- Add decideNextPollKind() / undoPollDecision() on workflowTaskPoller
- Modify getNextPollRequest() to accept a *pollHint and skip inline re-decision when a hint is provided
- Thread the hint through taskPoller.PollTask -> poll -> getNextPollRequest
- Handle cancellation cleanup (undo counters if reservation fails)
- Expand test coverage for counter lifecycle and hint propagation

Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent d7a7af3 commit d635ac7

6 files changed

Lines changed: 297 additions & 32 deletions

File tree

internal/internal_nexus_task_poller.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ func (ntp *nexusTaskPoller) Cleanup() error {
9898
}
9999

100100
// PollTask polls a new task
101-
func (ntp *nexusTaskPoller) PollTask() (taskForWorker, error) {
101+
func (ntp *nexusTaskPoller) PollTask(_ *pollHint) (taskForWorker, error) {
102102
return ntp.doPoll(ntp.poll)
103103
}
104104

internal/internal_task_handlers_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1832,7 +1832,7 @@ func (t *TaskHandlersTestSuite) TestLocalActivityRetry_Workflow() {
18321832
laTaskPoller := newLocalActivityPoller(params, laTunnel, nil, nil, stopCh)
18331833
go func() {
18341834
for {
1835-
task, _ := laTaskPoller.PollTask()
1835+
task, _ := laTaskPoller.PollTask(nil)
18361836
_ = laTaskPoller.ProcessTask(task)
18371837
// Quit after we've polled enough times
18381838
if laFailures.Load() == 4 {
@@ -1915,7 +1915,7 @@ func (t *TaskHandlersTestSuite) TestLocalActivityRetry_WorkflowTaskHeartbeatFail
19151915
doneCh := make(chan struct{})
19161916
go func() {
19171917
// laTaskPoller needs to poll the local activity and process it
1918-
task, err := laTaskPoller.PollTask()
1918+
task, err := laTaskPoller.PollTask(nil)
19191919
t.NoError(err)
19201920
err = laTaskPoller.ProcessTask(task)
19211921
t.NoError(err)

internal/internal_task_pollers.go

Lines changed: 77 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,19 @@ const (
4848
Sticky
4949
)
5050

51+
// pollHint carries a pre-decided sticky/normal decision from runPoller to PollTask.
52+
// When non-nil, the workflow task poller uses the hint instead of re-deciding.
53+
type pollHint struct {
54+
isSticky bool
55+
}
56+
5157
type (
5258
// taskPoller interface to poll for tasks
5359
taskPoller interface {
54-
// PollTask polls for one new task
55-
PollTask() (taskForWorker, error)
60+
// PollTask polls for one new task. The hint, when non-nil, carries a
61+
// pre-decided sticky/normal choice for workflow task pollers. Non-workflow
62+
// pollers should ignore it.
63+
PollTask(hint *pollHint) (taskForWorker, error)
5664
// Called when the poller will no longer be polled. Presently only useful for
5765
// workflow workers.
5866
Cleanup() error
@@ -373,10 +381,13 @@ func (wtp *workflowTaskPoller) Cleanup() error {
373381
return err
374382
}
375383

376-
// PollTask polls a new task
377-
func (wtp *workflowTaskPoller) PollTask() (taskForWorker, error) {
384+
// PollTask polls a new task. If hint is non-nil, the pre-decided sticky/normal
385+
// choice is used instead of re-deciding in getNextPollRequest.
386+
func (wtp *workflowTaskPoller) PollTask(hint *pollHint) (taskForWorker, error) {
378387
// Get the task.
379-
workflowTask, err := wtp.doPoll(wtp.poll)
388+
workflowTask, err := wtp.doPoll(func(ctx context.Context) (taskForWorker, error) {
389+
return wtp.poll(ctx, hint)
390+
})
380391
if err != nil {
381392
return nil, err
382393
}
@@ -727,7 +738,7 @@ func (latp *localActivityTaskPoller) Cleanup() error {
727738
return nil
728739
}
729740

730-
func (latp *localActivityTaskPoller) PollTask() (taskForWorker, error) {
741+
func (latp *localActivityTaskPoller) PollTask(_ *pollHint) (taskForWorker, error) {
731742
return latp.laTunnel.getTask(), nil
732743
}
733744

@@ -901,16 +912,50 @@ func (wtp *workflowTaskPoller) updateBacklog(taskQueueKind enumspb.TaskQueueKind
901912
wtp.requestLock.Unlock()
902913
}
903914

915+
// decideNextPollKind commits to a sticky/normal decision for the next poll.
916+
// It increments the appropriate pending counter and returns whether the poll
917+
// should target the sticky queue. Only meaningful in Mixed mode; for Sticky and
918+
// NonSticky modes the answer is deterministic and no counter is changed.
919+
func (wtp *workflowTaskPoller) decideNextPollKind() (isSticky bool) {
920+
if wtp.mode != Mixed || wtp.stickyCacheSize <= 0 {
921+
return wtp.mode == Sticky
922+
}
923+
wtp.requestLock.Lock()
924+
defer wtp.requestLock.Unlock()
925+
if wtp.stickyBacklog > 0 || wtp.pendingStickyPollCount <= wtp.pendingRegularPollCount {
926+
wtp.pendingStickyPollCount++
927+
return true
928+
}
929+
wtp.pendingRegularPollCount++
930+
return false
931+
}
932+
933+
// undoPollDecision reverses the counter increment made by decideNextPollKind.
934+
// Call this when a pre-decided poll will not happen (e.g. slot reservation
935+
// was cancelled or failed).
936+
func (wtp *workflowTaskPoller) undoPollDecision(isSticky bool) {
937+
if wtp.mode != Mixed || wtp.stickyCacheSize <= 0 {
938+
return
939+
}
940+
if isSticky {
941+
wtp.release(enumspb.TASK_QUEUE_KIND_STICKY)
942+
} else {
943+
wtp.release(enumspb.TASK_QUEUE_KIND_NORMAL)
944+
}
945+
}
946+
904947
// getNextPollRequest returns appropriate next poll request based on poller configuration and mode.
905948
// Simple rules:
906949
// 1. if mode is NonSticky, always poll from regular task queue
907950
// 2. if mode is Sticky, always poll from sticky task queue
908951
// 3. if mode is Mixed
909952
// 3.1. if sticky execution is disabled, always poll for regular task queue
910-
// 3.2. otherwise:
911-
// 3.2.1) if sticky task queue has backlog, always prefer to process sticky task first
912-
// 3.2.2) poll from the task queue that has less pending requests (prefer sticky when they are the same).
913-
func (wtp *workflowTaskPoller) getNextPollRequest() (request *workflowservice.PollWorkflowTaskQueueRequest) {
953+
// 3.2. otherwise, if a hint is provided, use the pre-decided kind (counter
954+
// was already incremented by decideNextPollKind)
955+
// 3.3. otherwise (no hint / nil hint):
956+
// 3.3.1) if sticky task queue has backlog, always prefer to process sticky task first
957+
// 3.3.2) poll from the task queue that has less pending requests (prefer sticky when they are the same).
958+
func (wtp *workflowTaskPoller) getNextPollRequest(hint *pollHint) (request *workflowservice.PollWorkflowTaskQueueRequest) {
914959
taskQueue := &taskqueuepb.TaskQueue{
915960
Name: wtp.taskQueueName,
916961
Kind: enumspb.TASK_QUEUE_KIND_NORMAL,
@@ -923,16 +968,27 @@ func (wtp *workflowTaskPoller) getNextPollRequest() (request *workflowservice.Po
923968
taskQueue.Kind = enumspb.TASK_QUEUE_KIND_STICKY
924969
taskQueue.NormalName = wtp.taskQueueName
925970
} else if wtp.mode == Mixed {
926-
wtp.requestLock.Lock()
927-
if wtp.stickyBacklog > 0 || wtp.pendingStickyPollCount <= wtp.pendingRegularPollCount {
928-
wtp.pendingStickyPollCount++
929-
taskQueue.Name = getWorkerTaskQueue(wtp.stickyUUID)
930-
taskQueue.Kind = enumspb.TASK_QUEUE_KIND_STICKY
931-
taskQueue.NormalName = wtp.taskQueueName
971+
if hint != nil {
972+
// Use the pre-decided kind; counter was already incremented
973+
// by decideNextPollKind.
974+
if hint.isSticky {
975+
taskQueue.Name = getWorkerTaskQueue(wtp.stickyUUID)
976+
taskQueue.Kind = enumspb.TASK_QUEUE_KIND_STICKY
977+
taskQueue.NormalName = wtp.taskQueueName
978+
}
932979
} else {
933-
wtp.pendingRegularPollCount++
980+
// Fallback: decide inline (original behavior).
981+
wtp.requestLock.Lock()
982+
if wtp.stickyBacklog > 0 || wtp.pendingStickyPollCount <= wtp.pendingRegularPollCount {
983+
wtp.pendingStickyPollCount++
984+
taskQueue.Name = getWorkerTaskQueue(wtp.stickyUUID)
985+
taskQueue.Kind = enumspb.TASK_QUEUE_KIND_STICKY
986+
taskQueue.NormalName = wtp.taskQueueName
987+
} else {
988+
wtp.pendingRegularPollCount++
989+
}
990+
wtp.requestLock.Unlock()
934991
}
935-
wtp.requestLock.Unlock()
936992
} else {
937993
panic("unknown workflow task poller mode")
938994
}
@@ -973,12 +1029,12 @@ func (wtp *workflowTaskPoller) pollWorkflowTaskQueue(ctx context.Context, reques
9731029
}
9741030

9751031
// Poll for a single workflow task from the service
976-
func (wtp *workflowTaskPoller) poll(ctx context.Context) (taskForWorker, error) {
1032+
func (wtp *workflowTaskPoller) poll(ctx context.Context, hint *pollHint) (taskForWorker, error) {
9771033
traceLog(func() {
9781034
wtp.logger.Debug("workflowTaskPoller::Poll")
9791035
})
9801036

981-
request := wtp.getNextPollRequest()
1037+
request := wtp.getNextPollRequest(hint)
9821038
defer wtp.release(request.TaskQueue.GetKind())
9831039

9841040
response, err := wtp.pollWorkflowTaskQueue(ctx, request)
@@ -1207,7 +1263,7 @@ func (atp *activityTaskPoller) Cleanup() error {
12071263
}
12081264

12091265
// PollTask polls a new task
1210-
func (atp *activityTaskPoller) PollTask() (taskForWorker, error) {
1266+
func (atp *activityTaskPoller) PollTask(_ *pollHint) (taskForWorker, error) {
12111267
// Get the task.
12121268
activityTask, err := atp.doPoll(atp.poll)
12131269
if err != nil {

internal/internal_worker_base.go

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -444,16 +444,26 @@ func (bw *baseWorker) runPoller(taskWorker scalableTaskPoller) {
444444
}
445445
}
446446

447+
// Pre-decide sticky/normal for workflow task pollers so that
448+
// IsSticky() is accurate at slot reservation time.
447449
data := bw.options.slotReservationData
448-
if wtp, ok := taskWorker.taskPoller.(*workflowTaskPoller); ok && wtp.mode == Sticky {
449-
data.isSticky = true
450+
var hint *pollHint
451+
wtp, isWorkflowPoller := taskWorker.taskPoller.(*workflowTaskPoller)
452+
if isWorkflowPoller {
453+
isSticky := wtp.decideNextPollKind()
454+
data.isSticky = isSticky
455+
hint = &pollHint{isSticky: isSticky}
450456
}
451457

452458
bw.stopWG.Add(1)
453459
go func() {
454460
defer bw.stopWG.Done()
455461
s, err := bw.slotSupplier.ReserveSlot(ctx, &data)
456462
if err != nil {
463+
// Undo pre-decided counter since the poll will not happen.
464+
if isWorkflowPoller {
465+
wtp.undoPollDecision(hint.isSticky)
466+
}
457467
if !errors.Is(err, context.Canceled) {
458468
bw.logger.Error("Error while trying to reserve slot", "error", err)
459469
select {
@@ -467,6 +477,10 @@ func (bw *baseWorker) runPoller(taskWorker scalableTaskPoller) {
467477
select {
468478
case reserveChan <- s:
469479
case <-ctx.Done():
480+
// Worker stopped: undo pre-decided counter and release the slot.
481+
if isWorkflowPoller {
482+
wtp.undoPollDecision(hint.isSticky)
483+
}
470484
bw.releaseSlot(s, SlotReleaseReasonUnused)
471485
}
472486
}()
@@ -488,7 +502,7 @@ func (bw *baseWorker) runPoller(taskWorker scalableTaskPoller) {
488502
if bw.pollerBalancer != nil {
489503
bw.pollerBalancer.incrementPoller(taskWorker.taskPollerType)
490504
}
491-
bw.pollTask(taskWorker, permit)
505+
bw.pollTask(taskWorker, permit, hint)
492506
if bw.pollerBalancer != nil {
493507
bw.pollerBalancer.decrementPoller(taskWorker.taskPollerType)
494508
}
@@ -595,7 +609,7 @@ func (bw *baseWorker) runEagerTaskDispatcher() {
595609
}
596610
}
597611

598-
func (bw *baseWorker) pollTask(taskWorker scalableTaskPoller, slotPermit *SlotPermit) {
612+
func (bw *baseWorker) pollTask(taskWorker scalableTaskPoller, slotPermit *SlotPermit, hint *pollHint) {
599613
var err error
600614
var task taskForWorker
601615
didSendTask := false
@@ -607,7 +621,7 @@ func (bw *baseWorker) pollTask(taskWorker scalableTaskPoller, slotPermit *SlotPe
607621

608622
bw.retrier.Throttle(bw.stopCh)
609623
if bw.pollLimiter == nil || bw.pollLimiter.Wait(bw.limiterContext) == nil {
610-
task, err = taskWorker.taskPoller.PollTask()
624+
task, err = taskWorker.taskPoller.PollTask(hint)
611625
bw.logPollTaskError(err)
612626
if err != nil {
613627
// We retry "non retriable" errors while long polling for a while, because some proxies return

internal/tuning.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,7 @@ type SlotReservationInfo interface {
6060
MetricsHandler() metrics.Handler
6161
// IsSticky returns true if the slot being reserved will be used to poll a sticky task queue.
6262
// This is only meaningful for workflow task slots. For activity and local activity slots,
63-
// this will always return false. When the worker is configured with a mixed-mode poller
64-
// (the default when not using autoscaling), this will also return false because the
65-
// sticky-vs-normal decision is made after the slot is reserved.
63+
// this will always return false.
6664
IsSticky() bool
6765
}
6866

0 commit comments

Comments
 (0)