diff --git a/internal/internal_nexus_task_poller.go b/internal/internal_nexus_task_poller.go index 6537f3454..083712825 100644 --- a/internal/internal_nexus_task_poller.go +++ b/internal/internal_nexus_task_poller.go @@ -98,7 +98,7 @@ func (ntp *nexusTaskPoller) Cleanup() error { } // PollTask polls a new task -func (ntp *nexusTaskPoller) PollTask() (taskForWorker, error) { +func (ntp *nexusTaskPoller) PollTask(_ *pollHint) (taskForWorker, error) { return ntp.doPoll(ntp.poll) } diff --git a/internal/internal_task_handlers_test.go b/internal/internal_task_handlers_test.go index a4e142d0d..fbef3fbde 100644 --- a/internal/internal_task_handlers_test.go +++ b/internal/internal_task_handlers_test.go @@ -1832,7 +1832,7 @@ func (t *TaskHandlersTestSuite) TestLocalActivityRetry_Workflow() { laTaskPoller := newLocalActivityPoller(params, laTunnel, nil, nil, stopCh) go func() { for { - task, _ := laTaskPoller.PollTask() + task, _ := laTaskPoller.PollTask(nil) _ = laTaskPoller.ProcessTask(task) // Quit after we've polled enough times if laFailures.Load() == 4 { @@ -1915,7 +1915,7 @@ func (t *TaskHandlersTestSuite) TestLocalActivityRetry_WorkflowTaskHeartbeatFail doneCh := make(chan struct{}) go func() { // laTaskPoller needs to poll the local activity and process it - task, err := laTaskPoller.PollTask() + task, err := laTaskPoller.PollTask(nil) t.NoError(err) err = laTaskPoller.ProcessTask(task) t.NoError(err) diff --git a/internal/internal_task_pollers.go b/internal/internal_task_pollers.go index dba946879..9b16bd9b9 100644 --- a/internal/internal_task_pollers.go +++ b/internal/internal_task_pollers.go @@ -48,11 +48,19 @@ const ( Sticky ) +// pollHint carries a pre-decided sticky/normal decision from runPoller to PollTask. +// When non-nil, the workflow task poller uses the hint instead of re-deciding. +type pollHint struct { + isSticky bool +} + type ( // taskPoller interface to poll for tasks taskPoller interface { - // PollTask polls for one new task - PollTask() (taskForWorker, error) + // PollTask polls for one new task. The hint, when non-nil, carries a + // pre-decided sticky/normal choice for workflow task pollers. Non-workflow + // pollers should ignore it. + PollTask(hint *pollHint) (taskForWorker, error) // Called when the poller will no longer be polled. Presently only useful for // workflow workers. Cleanup() error @@ -373,10 +381,13 @@ func (wtp *workflowTaskPoller) Cleanup() error { return err } -// PollTask polls a new task -func (wtp *workflowTaskPoller) PollTask() (taskForWorker, error) { +// PollTask polls a new task. If hint is non-nil, the pre-decided sticky/normal +// choice is used instead of re-deciding in getNextPollRequest. +func (wtp *workflowTaskPoller) PollTask(hint *pollHint) (taskForWorker, error) { // Get the task. - workflowTask, err := wtp.doPoll(wtp.poll) + workflowTask, err := wtp.doPoll(func(ctx context.Context) (taskForWorker, error) { + return wtp.poll(ctx, hint) + }) if err != nil { return nil, err } @@ -741,7 +752,7 @@ func (latp *localActivityTaskPoller) Cleanup() error { return nil } -func (latp *localActivityTaskPoller) PollTask() (taskForWorker, error) { +func (latp *localActivityTaskPoller) PollTask(_ *pollHint) (taskForWorker, error) { return latp.laTunnel.getTask(), nil } @@ -915,16 +926,50 @@ func (wtp *workflowTaskPoller) updateBacklog(taskQueueKind enumspb.TaskQueueKind wtp.requestLock.Unlock() } +// decideNextPollKind commits to a sticky/normal decision for the next poll. 
+// It increments the appropriate pending counter and returns whether the poll +// should target the sticky queue. Only meaningful in Mixed mode; for Sticky and +// NonSticky modes the answer is deterministic and no counter is changed. +func (wtp *workflowTaskPoller) decideNextPollKind() (isSticky bool) { + if wtp.mode != Mixed || wtp.stickyCacheSize <= 0 { + return wtp.mode == Sticky + } + wtp.requestLock.Lock() + defer wtp.requestLock.Unlock() + if wtp.stickyBacklog > 0 || wtp.pendingStickyPollCount <= wtp.pendingRegularPollCount { + wtp.pendingStickyPollCount++ + return true + } + wtp.pendingRegularPollCount++ + return false +} + +// undoPollDecision reverses the counter increment made by decideNextPollKind. +// Call this when a pre-decided poll will not happen (e.g. slot reservation +// was cancelled or failed). +func (wtp *workflowTaskPoller) undoPollDecision(isSticky bool) { + if wtp.mode != Mixed || wtp.stickyCacheSize <= 0 { + return + } + if isSticky { + wtp.release(enumspb.TASK_QUEUE_KIND_STICKY) + } else { + wtp.release(enumspb.TASK_QUEUE_KIND_NORMAL) + } +} + // getNextPollRequest returns appropriate next poll request based on poller configuration and mode. // Simple rules: // 1. if mode is NonSticky, always poll from regular task queue // 2. if mode is Sticky, always poll from sticky task queue // 3. if mode is Mixed // 3.1. if sticky execution is disabled, always poll for regular task queue -// 3.2. otherwise: -// 3.2.1) if sticky task queue has backlog, always prefer to process sticky task first -// 3.2.2) poll from the task queue that has less pending requests (prefer sticky when they are the same). -func (wtp *workflowTaskPoller) getNextPollRequest() (request *workflowservice.PollWorkflowTaskQueueRequest) { +// 3.2. otherwise, if a hint is provided, use the pre-decided kind (counter +// was already incremented by decideNextPollKind) +// 3.3. otherwise (no hint / nil hint): +// 3.3.1) if sticky task queue has backlog, always prefer to process sticky task first +// 3.3.2) poll from the task queue that has less pending requests (prefer sticky when they are the same). +func (wtp *workflowTaskPoller) getNextPollRequest(hint *pollHint) (request *workflowservice.PollWorkflowTaskQueueRequest) { taskQueue := &taskqueuepb.TaskQueue{ Name: wtp.taskQueueName, Kind: enumspb.TASK_QUEUE_KIND_NORMAL, @@ -937,16 +982,27 @@ func (wtp *workflowTaskPoller) getNextPollRequest() (request *workflowservice.Po taskQueue.Kind = enumspb.TASK_QUEUE_KIND_STICKY taskQueue.NormalName = wtp.taskQueueName } else if wtp.mode == Mixed { - wtp.requestLock.Lock() - if wtp.stickyBacklog > 0 || wtp.pendingStickyPollCount <= wtp.pendingRegularPollCount { - wtp.pendingStickyPollCount++ - taskQueue.Name = getWorkerTaskQueue(wtp.stickyUUID) - taskQueue.Kind = enumspb.TASK_QUEUE_KIND_STICKY - taskQueue.NormalName = wtp.taskQueueName + if hint != nil { + // Use the pre-decided kind; counter was already incremented + // by decideNextPollKind. + if hint.isSticky { + taskQueue.Name = getWorkerTaskQueue(wtp.stickyUUID) + taskQueue.Kind = enumspb.TASK_QUEUE_KIND_STICKY + taskQueue.NormalName = wtp.taskQueueName + } } else { - wtp.pendingRegularPollCount++ + // Fallback: decide inline (original behavior). 
+ wtp.requestLock.Lock() + if wtp.stickyBacklog > 0 || wtp.pendingStickyPollCount <= wtp.pendingRegularPollCount { + wtp.pendingStickyPollCount++ + taskQueue.Name = getWorkerTaskQueue(wtp.stickyUUID) + taskQueue.Kind = enumspb.TASK_QUEUE_KIND_STICKY + taskQueue.NormalName = wtp.taskQueueName + } else { + wtp.pendingRegularPollCount++ + } + wtp.requestLock.Unlock() } - wtp.requestLock.Unlock() } else { panic("unknown workflow task poller mode") } @@ -987,12 +1043,12 @@ func (wtp *workflowTaskPoller) pollWorkflowTaskQueue(ctx context.Context, reques } // Poll for a single workflow task from the service -func (wtp *workflowTaskPoller) poll(ctx context.Context) (taskForWorker, error) { +func (wtp *workflowTaskPoller) poll(ctx context.Context, hint *pollHint) (taskForWorker, error) { traceLog(func() { wtp.logger.Debug("workflowTaskPoller::Poll") }) - request := wtp.getNextPollRequest() + request := wtp.getNextPollRequest(hint) defer wtp.release(request.TaskQueue.GetKind()) response, err := wtp.pollWorkflowTaskQueue(ctx, request) @@ -1221,7 +1277,7 @@ func (atp *activityTaskPoller) Cleanup() error { } // PollTask polls a new task -func (atp *activityTaskPoller) PollTask() (taskForWorker, error) { +func (atp *activityTaskPoller) PollTask(_ *pollHint) (taskForWorker, error) { // Get the task. activityTask, err := atp.doPoll(atp.poll) if err != nil { diff --git a/internal/internal_worker_base.go b/internal/internal_worker_base.go index 4f2ae1493..bd21ae7b7 100644 --- a/internal/internal_worker_base.go +++ b/internal/internal_worker_base.go @@ -442,11 +442,26 @@ func (bw *baseWorker) runPoller(taskWorker scalableTaskPoller) { } } + // Pre-decide sticky/normal for workflow task pollers so that + // IsSticky() is accurate at slot reservation time. + data := bw.options.slotReservationData + var hint *pollHint + wtp, isWorkflowPoller := taskWorker.taskPoller.(*workflowTaskPoller) + if isWorkflowPoller { + isSticky := wtp.decideNextPollKind() + data.isSticky = isSticky + hint = &pollHint{isSticky: isSticky} + } + bw.stopWG.Add(1) go func() { defer bw.stopWG.Done() - s, err := bw.slotSupplier.ReserveSlot(ctx, &bw.options.slotReservationData) + s, err := bw.slotSupplier.ReserveSlot(ctx, &data) if err != nil { + // Undo pre-decided counter since the poll will not happen. + if isWorkflowPoller { + wtp.undoPollDecision(hint.isSticky) + } if !errors.Is(err, context.Canceled) { bw.logger.Error("Error while trying to reserve slot", "error", err) select { @@ -460,6 +475,10 @@ func (bw *baseWorker) runPoller(taskWorker scalableTaskPoller) { select { case reserveChan <- s: case <-ctx.Done(): + // Worker stopped: undo pre-decided counter and release the slot. 
+ if isWorkflowPoller { + wtp.undoPollDecision(hint.isSticky) + } bw.releaseSlot(s, SlotReleaseReasonUnused) } }() @@ -481,7 +500,7 @@ func (bw *baseWorker) runPoller(taskWorker scalableTaskPoller) { if bw.pollerBalancer != nil { bw.pollerBalancer.incrementPoller(taskWorker.taskPollerType) } - bw.pollTask(taskWorker, permit) + bw.pollTask(taskWorker, permit, hint) if bw.pollerBalancer != nil { bw.pollerBalancer.decrementPoller(taskWorker.taskPollerType) } @@ -588,7 +607,7 @@ func (bw *baseWorker) runEagerTaskDispatcher() { } } -func (bw *baseWorker) pollTask(taskWorker scalableTaskPoller, slotPermit *SlotPermit) { +func (bw *baseWorker) pollTask(taskWorker scalableTaskPoller, slotPermit *SlotPermit, hint *pollHint) { var err error var task taskForWorker didSendTask := false @@ -600,7 +619,7 @@ func (bw *baseWorker) pollTask(taskWorker scalableTaskPoller, slotPermit *SlotPe bw.retrier.Throttle(bw.stopCh) if bw.pollLimiter == nil || bw.pollLimiter.Wait(bw.limiterContext) == nil { - task, err = taskWorker.taskPoller.PollTask() + task, err = taskWorker.taskPoller.PollTask(hint) bw.logPollTaskError(err) if err != nil { // We retry "non retriable" errors while long polling for a while, because some proxies return diff --git a/internal/internal_worker_base_test.go b/internal/internal_worker_base_test.go index e900dbba7..bd1c30a77 100644 --- a/internal/internal_worker_base_test.go +++ b/internal/internal_worker_base_test.go @@ -234,7 +234,7 @@ func newSemaphoreProbeTaskPoller() *semaphoreProbeTaskPoller { } // PollTask implements taskPoller and blocks until a signal is provided so the semaphore permits stay acquired. -func (p *semaphoreProbeTaskPoller) PollTask() (taskForWorker, error) { +func (p *semaphoreProbeTaskPoller) PollTask(_ *pollHint) (taskForWorker, error) { _, ok := <-p.signals if !ok { return nil, nil diff --git a/internal/tuning.go b/internal/tuning.go index 8d146800f..f510fd2b4 100644 --- a/internal/tuning.go +++ b/internal/tuning.go @@ -58,6 +58,10 @@ type SlotReservationInfo interface { // MetricsHandler returns an appropriately tagged metrics handler that can be used to record // custom metrics. MetricsHandler() metrics.Handler + // IsSticky returns true if the slot being reserved will be used to poll a sticky task queue. + // This is only meaningful for workflow task slots. For activity and local activity slots, + // this will always return false. 
+ IsSticky() bool } // SlotMarkUsedInfo contains information that SlotSupplier instances can use during @@ -287,6 +291,7 @@ func (f *FixedSizeSlotSupplier) MaxSlots() int { type slotReservationData struct { taskQueue string + isSticky bool } type slotReserveInfoImpl struct { @@ -296,6 +301,7 @@ type slotReserveInfoImpl struct { issuedSlots *atomic.Int32 logger log.Logger metrics metrics.Handler + sticky bool } func (s slotReserveInfoImpl) TaskQueue() string { @@ -322,6 +328,10 @@ func (s slotReserveInfoImpl) MetricsHandler() metrics.Handler { return s.metrics } +func (s slotReserveInfoImpl) IsSticky() bool { + return s.sticky +} + type slotMarkUsedContextImpl struct { permit *SlotPermit logger log.Logger @@ -410,6 +420,7 @@ func (t *trackingSlotSupplier) ReserveSlot( issuedSlots: &t.issuedSlotsAtomic, logger: t.logger, metrics: t.metrics, + sticky: data.isSticky, }) if err != nil { return nil, err @@ -433,6 +444,7 @@ func (t *trackingSlotSupplier) TryReserveSlot(data *slotReservationData) *SlotPe issuedSlots: &t.issuedSlotsAtomic, logger: t.logger, metrics: t.metrics, + sticky: data.isSticky, }) if permit != nil { t.issuedSlotsAtomic.Add(1) diff --git a/internal/tuning_test.go b/internal/tuning_test.go new file mode 100644 index 000000000..c5ef0122a --- /dev/null +++ b/internal/tuning_test.go @@ -0,0 +1,312 @@ +package internal + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + enumspb "go.temporal.io/api/enums/v1" + + "go.temporal.io/sdk/internal/common/metrics" + ilog "go.temporal.io/sdk/internal/log" +) + +func TestSlotReserveInfoImpl_IsSticky(t *testing.T) { + t.Run("sticky true", func(t *testing.T) { + info := slotReserveInfoImpl{ + taskQueue: "test-queue", + workerBuildId: "build1", + workerIdentity: "worker1", + issuedSlots: &atomic.Int32{}, + logger: ilog.NewDefaultLogger(), + metrics: metrics.NopHandler, + sticky: true, + } + assert.True(t, info.IsSticky()) + }) + + t.Run("sticky false", func(t *testing.T) { + info := slotReserveInfoImpl{ + taskQueue: "test-queue", + workerBuildId: "build1", + workerIdentity: "worker1", + issuedSlots: &atomic.Int32{}, + logger: ilog.NewDefaultLogger(), + metrics: metrics.NopHandler, + sticky: false, + } + assert.False(t, info.IsSticky()) + }) +} + +// capturingSlotSupplier records the IsSticky() value from the most recent reservation call. 
+type capturingSlotSupplier struct { + lastReserveIsSticky atomic.Bool + lastTryReserveIsSticky atomic.Bool + reserveCalled atomic.Bool + tryReserveCalled atomic.Bool +} + +func (c *capturingSlotSupplier) ReserveSlot(_ context.Context, info SlotReservationInfo) (*SlotPermit, error) { + c.lastReserveIsSticky.Store(info.IsSticky()) + c.reserveCalled.Store(true) + return &SlotPermit{}, nil +} + +func (c *capturingSlotSupplier) TryReserveSlot(info SlotReservationInfo) *SlotPermit { + c.lastTryReserveIsSticky.Store(info.IsSticky()) + c.tryReserveCalled.Store(true) + return &SlotPermit{} +} + +func (c *capturingSlotSupplier) MarkSlotUsed(SlotMarkUsedInfo) {} +func (c *capturingSlotSupplier) ReleaseSlot(SlotReleaseInfo) {} +func (c *capturingSlotSupplier) MaxSlots() int { return 100 } + +func TestTrackingSlotSupplier_PropagatesIsSticky(t *testing.T) { + capturing := &capturingSlotSupplier{} + tss := newTrackingSlotSupplier(capturing, trackingSlotSupplierOptions{ + logger: ilog.NewDefaultLogger(), + metricsHandler: metrics.NopHandler, + }) + + t.Run("ReserveSlot with sticky true", func(t *testing.T) { + data := &slotReservationData{taskQueue: "test-queue", isSticky: true} + permit, err := tss.ReserveSlot(context.Background(), data) + require.NoError(t, err) + require.NotNil(t, permit) + assert.True(t, capturing.reserveCalled.Load()) + assert.True(t, capturing.lastReserveIsSticky.Load(), + "inner supplier should see IsSticky()=true") + }) + + t.Run("ReserveSlot with sticky false", func(t *testing.T) { + data := &slotReservationData{taskQueue: "test-queue", isSticky: false} + permit, err := tss.ReserveSlot(context.Background(), data) + require.NoError(t, err) + require.NotNil(t, permit) + assert.False(t, capturing.lastReserveIsSticky.Load(), + "inner supplier should see IsSticky()=false") + }) + + t.Run("TryReserveSlot with sticky true", func(t *testing.T) { + data := &slotReservationData{taskQueue: "test-queue", isSticky: true} + permit := tss.TryReserveSlot(data) + require.NotNil(t, permit) + assert.True(t, capturing.tryReserveCalled.Load()) + assert.True(t, capturing.lastTryReserveIsSticky.Load(), + "inner supplier should see IsSticky()=true on TryReserveSlot") + }) + + t.Run("TryReserveSlot with sticky false", func(t *testing.T) { + data := &slotReservationData{taskQueue: "test-queue", isSticky: false} + permit := tss.TryReserveSlot(data) + require.NotNil(t, permit) + assert.False(t, capturing.lastTryReserveIsSticky.Load(), + "inner supplier should see IsSticky()=false on TryReserveSlot") + }) +} + +func TestSlotReservationData_IsStickyDefault(t *testing.T) { + // By default, slotReservationData should have isSticky=false + data := slotReservationData{taskQueue: "test-queue"} + assert.False(t, data.isSticky, "default isSticky should be false") +} + +func TestDecideNextPollKind_StickyMode(t *testing.T) { + wtp := &workflowTaskPoller{mode: Sticky, stickyCacheSize: 10} + assert.True(t, wtp.decideNextPollKind(), "Sticky mode should always return true") + // Counters should not be touched for non-Mixed modes. 
+ assert.Equal(t, 0, wtp.pendingStickyPollCount) + assert.Equal(t, 0, wtp.pendingRegularPollCount) +} + +func TestDecideNextPollKind_NonStickyMode(t *testing.T) { + wtp := &workflowTaskPoller{mode: NonSticky, stickyCacheSize: 10} + assert.False(t, wtp.decideNextPollKind(), "NonSticky mode should always return false") + assert.Equal(t, 0, wtp.pendingStickyPollCount) + assert.Equal(t, 0, wtp.pendingRegularPollCount) +} + +func TestDecideNextPollKind_MixedMode_Balancing(t *testing.T) { + wtp := &workflowTaskPoller{mode: Mixed, stickyCacheSize: 10} + + // First decision: both counts are 0, sticky preferred when equal. + isSticky := wtp.decideNextPollKind() + assert.True(t, isSticky, "first decision should be sticky (counts equal)") + assert.Equal(t, 1, wtp.pendingStickyPollCount) + assert.Equal(t, 0, wtp.pendingRegularPollCount) + + // Second decision: sticky=1, regular=0, should pick regular to balance. + isSticky = wtp.decideNextPollKind() + assert.False(t, isSticky, "second decision should be regular (sticky > regular)") + assert.Equal(t, 1, wtp.pendingStickyPollCount) + assert.Equal(t, 1, wtp.pendingRegularPollCount) + + // Third decision: counts equal again, prefer sticky. + isSticky = wtp.decideNextPollKind() + assert.True(t, isSticky, "third decision should be sticky (counts equal)") + assert.Equal(t, 2, wtp.pendingStickyPollCount) + assert.Equal(t, 1, wtp.pendingRegularPollCount) +} + +func TestDecideNextPollKind_MixedMode_StickyBacklog(t *testing.T) { + wtp := &workflowTaskPoller{ + mode: Mixed, + stickyCacheSize: 10, + stickyBacklog: 5, + // Start with sticky > regular so that without backlog, regular would be preferred. + pendingStickyPollCount: 3, + pendingRegularPollCount: 1, + } + + // Even though sticky count > regular count, backlog forces sticky. + isSticky := wtp.decideNextPollKind() + assert.True(t, isSticky, "should choose sticky when backlog > 0") + assert.Equal(t, 4, wtp.pendingStickyPollCount) + assert.Equal(t, 1, wtp.pendingRegularPollCount) +} + +func TestDecideNextPollKind_MixedMode_DisabledStickyCache(t *testing.T) { + wtp := &workflowTaskPoller{mode: Mixed, stickyCacheSize: 0} + isSticky := wtp.decideNextPollKind() + assert.False(t, isSticky, "should return false when sticky cache is disabled") + assert.Equal(t, 0, wtp.pendingStickyPollCount) + assert.Equal(t, 0, wtp.pendingRegularPollCount) +} + +func TestUndoPollDecision_Mixed(t *testing.T) { + wtp := &workflowTaskPoller{mode: Mixed, stickyCacheSize: 10} + + // Simulate decideNextPollKind then undo. + isSticky := wtp.decideNextPollKind() + assert.True(t, isSticky) + assert.Equal(t, 1, wtp.pendingStickyPollCount) + + wtp.undoPollDecision(isSticky) + assert.Equal(t, 0, wtp.pendingStickyPollCount, "counter should be back to 0 after undo") + + // Do the same for a regular decision. + wtp.pendingStickyPollCount = 2 + wtp.pendingRegularPollCount = 0 + isSticky = wtp.decideNextPollKind() + assert.False(t, isSticky, "should pick regular since sticky count is higher") + assert.Equal(t, 1, wtp.pendingRegularPollCount) + + wtp.undoPollDecision(isSticky) + assert.Equal(t, 0, wtp.pendingRegularPollCount, "regular counter should be back to 0 after undo") +} + +func TestUndoPollDecision_NonMixed_Noop(t *testing.T) { + wtp := &workflowTaskPoller{mode: Sticky, stickyCacheSize: 10} + // Should not panic or change anything. 
+ wtp.undoPollDecision(true) + assert.Equal(t, 0, wtp.pendingStickyPollCount) + assert.Equal(t, 0, wtp.pendingRegularPollCount) +} + +func TestGetNextPollRequest_WithHint_Sticky(t *testing.T) { + wtp := &workflowTaskPoller{ + mode: Mixed, + stickyCacheSize: 10, + taskQueueName: "test-queue", + stickyUUID: "sticky-uuid-123", + namespace: "test-ns", + identity: "test-identity", + } + + hint := &pollHint{isSticky: true} + req := wtp.getNextPollRequest(hint) + + assert.Equal(t, getWorkerTaskQueue("sticky-uuid-123"), req.TaskQueue.GetName()) + assert.Equal(t, enumspb.TASK_QUEUE_KIND_STICKY, req.TaskQueue.GetKind()) + assert.Equal(t, "test-queue", req.TaskQueue.GetNormalName()) + + // Counters should NOT be incremented by getNextPollRequest when a hint is provided. + assert.Equal(t, 0, wtp.pendingStickyPollCount) + assert.Equal(t, 0, wtp.pendingRegularPollCount) +} + +func TestGetNextPollRequest_WithHint_Normal(t *testing.T) { + wtp := &workflowTaskPoller{ + mode: Mixed, + stickyCacheSize: 10, + taskQueueName: "test-queue", + stickyUUID: "sticky-uuid-123", + namespace: "test-ns", + identity: "test-identity", + } + + hint := &pollHint{isSticky: false} + req := wtp.getNextPollRequest(hint) + + assert.Equal(t, "test-queue", req.TaskQueue.GetName()) + assert.Equal(t, enumspb.TASK_QUEUE_KIND_NORMAL, req.TaskQueue.GetKind()) + + // Counters should NOT be incremented by getNextPollRequest when a hint is provided. + assert.Equal(t, 0, wtp.pendingStickyPollCount) + assert.Equal(t, 0, wtp.pendingRegularPollCount) +} + +func TestGetNextPollRequest_WithoutHint_FallsBackToInlineDecision(t *testing.T) { + wtp := &workflowTaskPoller{ + mode: Mixed, + stickyCacheSize: 10, + taskQueueName: "test-queue", + stickyUUID: "sticky-uuid-123", + namespace: "test-ns", + identity: "test-identity", + } + + // No hint: should use inline decision logic and increment counter. + req := wtp.getNextPollRequest(nil) + // First call with equal counts should pick sticky. + assert.Equal(t, enumspb.TASK_QUEUE_KIND_STICKY, req.TaskQueue.GetKind()) + assert.Equal(t, 1, wtp.pendingStickyPollCount, "counter should be incremented inline") +} + +func TestDecideAndPollHint_CounterLifecycle(t *testing.T) { + // Simulate the full lifecycle: decide -> (reserve succeeds) -> poll -> release. + wtp := &workflowTaskPoller{ + mode: Mixed, + stickyCacheSize: 10, + taskQueueName: "test-queue", + stickyUUID: "sticky-uuid-123", + namespace: "test-ns", + identity: "test-identity", + } + + // Step 1: Pre-decide (increments counter). + isSticky := wtp.decideNextPollKind() + assert.True(t, isSticky) + assert.Equal(t, 1, wtp.pendingStickyPollCount) + + // Step 2: getNextPollRequest with hint (does NOT increment counter). + hint := &pollHint{isSticky: isSticky} + req := wtp.getNextPollRequest(hint) + assert.Equal(t, enumspb.TASK_QUEUE_KIND_STICKY, req.TaskQueue.GetKind()) + assert.Equal(t, 1, wtp.pendingStickyPollCount, "counter unchanged by getNextPollRequest with hint") + + // Step 3: release (decrements counter, as happens after poll). + wtp.release(req.TaskQueue.GetKind()) + assert.Equal(t, 0, wtp.pendingStickyPollCount, "counter back to 0 after release") +} + +func TestDecideAndCancel_CounterLifecycle(t *testing.T) { + // Simulate: decide -> reservation fails -> undo. + wtp := &workflowTaskPoller{ + mode: Mixed, + stickyCacheSize: 10, + taskQueueName: "test-queue", + } + + isSticky := wtp.decideNextPollKind() + assert.True(t, isSticky) + assert.Equal(t, 1, wtp.pendingStickyPollCount) + + // Reservation failed, undo the decision. 
+ wtp.undoPollDecision(isSticky) + assert.Equal(t, 0, wtp.pendingStickyPollCount, "counter should be back to 0 after undo") +}
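
Outside the patch itself, the new SlotReservationInfo.IsSticky() accessor is what a custom SlotSupplier would consume. Below is a minimal sketch, assuming the same internal-package interface shape exercised by capturingSlotSupplier in the test file above; the stickyCountingSlotSupplier name, its fixed-capacity channel, and its counters are illustrative only and not part of this change.

package internal

import (
	"context"
	"sync/atomic"
)

// stickyCountingSlotSupplier is a hypothetical SlotSupplier that enforces a fixed
// capacity and tallies how many reservations targeted the sticky queue versus the
// normal queue, using the IsSticky() accessor introduced above.
type stickyCountingSlotSupplier struct {
	slots          chan struct{} // fixed-capacity token pool
	stickyReserved atomic.Int64  // total sticky-queue reservations observed
	normalReserved atomic.Int64  // total normal-queue reservations observed
}

func newStickyCountingSlotSupplier(capacity int) *stickyCountingSlotSupplier {
	return &stickyCountingSlotSupplier{slots: make(chan struct{}, capacity)}
}

func (s *stickyCountingSlotSupplier) record(info SlotReservationInfo) {
	// IsSticky() is only meaningful for workflow task slots; activity and local
	// activity reservations always report false.
	if info.IsSticky() {
		s.stickyReserved.Add(1)
	} else {
		s.normalReserved.Add(1)
	}
}

func (s *stickyCountingSlotSupplier) ReserveSlot(ctx context.Context, info SlotReservationInfo) (*SlotPermit, error) {
	select {
	case s.slots <- struct{}{}:
		s.record(info)
		return &SlotPermit{}, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}

func (s *stickyCountingSlotSupplier) TryReserveSlot(info SlotReservationInfo) *SlotPermit {
	select {
	case s.slots <- struct{}{}:
		s.record(info)
		return &SlotPermit{}
	default:
		return nil
	}
}

func (s *stickyCountingSlotSupplier) MarkSlotUsed(SlotMarkUsedInfo) {}

// ReleaseSlot returns a token to the shared pool; the counters track totals, so
// no per-kind bookkeeping is needed on release.
func (s *stickyCountingSlotSupplier) ReleaseSlot(SlotReleaseInfo) { <-s.slots }

func (s *stickyCountingSlotSupplier) MaxSlots() int { return cap(s.slots) }

Because runPoller pre-decides the sticky/normal kind before reserving, such a supplier sees the same IsSticky() value that the subsequent poll actually uses, which is the point of the decideNextPollKind/undoPollDecision lifecycle introduced in this diff.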