@@ -4,36 +4,88 @@ import (
44 "context"
55 "crypto/tls"
66 "fmt"
7+ "io"
78 "log/slog"
9+ "net"
10+ "net/http"
11+ "strconv"
812 "strings"
13+ "sync"
914 "time"
1015
1116 "github.com/go-logr/logr"
17+ "github.com/prometheus/client_golang/prometheus"
18+ dto "github.com/prometheus/client_model/go"
19+ "github.com/prometheus/common/expfmt"
1220 workflowservice "go.temporal.io/api/workflowservice/v1"
1321 sdk "go.temporal.io/sdk/client"
1422 sdklog "go.temporal.io/sdk/log"
1523 "google.golang.org/grpc"
1624 "google.golang.org/grpc/metadata"
1725 v2 "k8s.io/api/autoscaling/v2"
26+ corev1 "k8s.io/api/core/v1"
1827 "k8s.io/metrics/pkg/apis/external_metrics"
28+ "sigs.k8s.io/controller-runtime/pkg/client"
29+ ctrlmetrics "sigs.k8s.io/controller-runtime/pkg/metrics"
1930
2031 "github.com/kedacore/keda/v2/pkg/scalers/scalersconfig"
2132 kedautil "github.com/kedacore/keda/v2/pkg/util"
2233)
2334
const (
	// scrapeLoopTimeout is the total time budget for scraping all worker pods
	// in one metric evaluation; once exceeded, partial results are used.
	scrapeLoopTimeout = 12 * time.Second
	// slotsCacheTTL is how long a cached slots value remains valid after a
	// successful scrape before falling back to 0 on persistent failure.
	slotsCacheTTL = 180 * time.Second
	// maxMetricsResponseBytes limits the size of a single pod's /metrics response
	// to prevent OOM from misconfigured or malicious pods.
	maxMetricsResponseBytes = 10 * 1024 * 1024 // 10 MB
)
45+
var (
	// temporalDefauleQueueTypes lists the task queue types queried by default.
	// NOTE(review): identifier appears to misspell "Default"; renaming would
	// touch other references in this file, so it is left as-is here.
	temporalDefauleQueueTypes = []sdk.TaskQueueType{
		sdk.TaskQueueTypeActivity,
		sdk.TaskQueueTypeWorkflow,
		sdk.TaskQueueTypeNexus,
	}

	// temporalSlotsScrapeErrors counts worker slot scrape failures by reason:
	//   pod_scrape_error              – a single pod's /metrics request failed
	//   scrape_loop_timeout           – scrapeLoopTimeout budget exceeded, used partial results
	//   all_pods_failed_cache_hit     – all pods failed; returned last cached value
	//   all_pods_failed_cache_expired – all pods failed and cache expired; returned 0
	temporalSlotsScrapeErrors = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: "keda",
			Subsystem: "temporal_scaler",
			Name:      "worker_slots_scrape_errors_total",
			Help: "Total number of temporal worker slot scrape failures. " +
				"Use reason label to distinguish pod-level errors from full-scrape failures.",
		},
		[]string{"namespace", "task_queue", "reason"},
	)
)
3169
// init registers the scrape-error counter with the controller-runtime metrics
// registry so it is exposed on the operator's /metrics endpoint.
func init() {
	ctrlmetrics.Registry.MustRegister(temporalSlotsScrapeErrors)
}
73+
// slotsCache holds the last successfully scraped total of used worker slots
// together with the time it was recorded; it serves as a fallback when all
// pod scrapes fail within slotsCacheTTL.
type slotsCache struct {
	value     int64     // last known good total of used worker slots
	timestamp time.Time // when value was recorded
}
78+
// temporalScaler scales on the Temporal task queue backlog, optionally
// augmented with the running workflow count and the number of in-use
// worker task slots scraped from worker pods.
type temporalScaler struct {
	metricType   v2.MetricTargetType
	metadata     *temporalMetadata
	tcl          sdk.Client    // Temporal SDK client for queue/workflow queries
	kubeClient   client.Client // lists worker pods for slot scraping; may be nil
	httpClient   *http.Client  // scrapes worker pod /metrics endpoints
	logger       logr.Logger
	podNamespace string // ScaledObject namespace; worker pods are discovered here
	slotsMu      sync.Mutex // guards lastSlots
	lastSlots    slotsCache // last successful worker-slots scrape result
}
3890
3991type temporalMetadata struct {
@@ -48,6 +100,7 @@ type temporalMetadata struct {
48100 Unversioned bool `keda:"name=selectUnversioned, order=triggerMetadata, default=false"`
49101 IncludeRunningWorkflowCount bool `keda:"name=includeRunningWorkflowCount, order=triggerMetadata, default=true"`
50102 WorkflowTaskQueueForCount string `keda:"name=workflowTaskQueueForCount, order=triggerMetadata;resolvedEnv, optional"`
103+ WorkerMetricsPort int `keda:"name=workerMetricsPort, order=triggerMetadata, default=9464"`
51104 APIKey string `keda:"name=apiKey, order=authParams;resolvedEnv, optional"`
52105 MinConnectTimeout int `keda:"name=minConnectTimeout, order=triggerMetadata, default=5"`
53106
@@ -77,10 +130,14 @@ func (a *temporalMetadata) Validate() error {
77130 return fmt .Errorf ("minConnectTimeout must be a positive number" )
78131 }
79132
133+ if a .WorkerMetricsPort < 1 || a .WorkerMetricsPort > 65535 {
134+ return fmt .Errorf ("workerMetricsPort must be between 1 and 65535" )
135+ }
136+
80137 return nil
81138}
82139
83- func NewTemporalScaler (ctx context.Context , config * scalersconfig.ScalerConfig ) (Scaler , error ) {
140+ func NewTemporalScaler (ctx context.Context , kubeClient client. Client , config * scalersconfig.ScalerConfig ) (Scaler , error ) {
84141 logger := InitializeLogger (config , "temporal_scaler" )
85142
86143 metricType , err := GetMetricTargetType (config )
@@ -99,10 +156,13 @@ func NewTemporalScaler(ctx context.Context, config *scalersconfig.ScalerConfig)
99156 }
100157
101158 return & temporalScaler {
102- metricType : metricType ,
103- metadata : meta ,
104- tcl : c ,
105- logger : logger ,
159+ metricType : metricType ,
160+ metadata : meta ,
161+ tcl : c ,
162+ kubeClient : kubeClient ,
163+ httpClient : kedautil .CreateHTTPClient (config .GlobalHTTPTimeout , false ),
164+ logger : logger ,
165+ podNamespace : config .ScalableObjectNamespace ,
106166 }, nil
107167}
108168
@@ -164,18 +224,25 @@ func (s *temporalScaler) getQueueSize(ctx context.Context) (int64, error) {
164224 }
165225
166226 backlog := getCombinedBacklogCount (resp )
227+ metric := backlog
167228
168- if ! s .metadata .IncludeRunningWorkflowCount {
169- return backlog , nil
229+ if s .metadata .IncludeRunningWorkflowCount {
230+ runningCount , err := s .getRunningWorkflowCount (ctx )
231+ if err != nil {
232+ s .logger .V (1 ).Info ("failed to get running workflow count, using backlog only" , "error" , err )
233+ } else {
234+ metric += runningCount
235+ }
170236 }
171237
172- runningCount , err := s .getRunningWorkflowCount (ctx )
238+ usedSlots , err := s .getUsedWorkerSlots (ctx )
173239 if err != nil {
174- s .logger .V (1 ).Info ("failed to get running workflow count, using backlog only" , "error" , err )
175- return backlog , nil
240+ s .logger .Info ("failed to get worker slots metric, excluding from metric" , "error" , err )
241+ } else {
242+ metric += usedSlots
176243 }
177244
178- return backlog + runningCount , nil
245+ return metric , nil
179246}
180247
181248// getRunningWorkflowCount returns the approximate number of running workflow executions
@@ -201,6 +268,171 @@ func (s *temporalScaler) getRunningWorkflowCount(ctx context.Context) (int64, er
201268 return resp .GetCount (), nil
202269}
203270
271+ // getUsedWorkerSlots discovers worker pods in the ScaledObject's namespace and
272+ // scrapes their Prometheus metrics endpoint to sum temporal_worker_task_slots_used
273+ // for worker types matching the configured queueTypes. This prevents premature
274+ // scale-down when workers are actively executing tasks but the task queue backlog
275+ // is empty.
276+ //
277+ // On transient failures (all pod scrapes fail), it returns the last known good
278+ // value if within the cache TTL. A total timeout budget bounds the scrape loop
279+ // so that slow/unreachable pods don't block the KEDA polling cycle.
// getUsedWorkerSlots discovers worker pods in the ScaledObject's namespace and
// scrapes their Prometheus metrics endpoint to sum temporal_worker_task_slots_used.
// This prevents premature scale-down when workers are actively executing tasks
// but the task queue backlog is empty.
//
// NOTE(review): despite the scaler's configurable queueTypes, the scrape
// currently counts only ActivityWorker slots (see scrapeWorkerSlots) —
// confirm this is intentional.
//
// On transient failures (all pod scrapes fail), it returns the last known good
// value if within the cache TTL. A total timeout budget bounds the scrape loop
// so that slow/unreachable pods don't block the KEDA polling cycle.
func (s *temporalScaler) getUsedWorkerSlots(ctx context.Context) (int64, error) {
	// Both clients are optional dependencies; without them the slots metric
	// cannot be computed at all.
	if s.kubeClient == nil || s.httpClient == nil {
		return 0, fmt.Errorf("kubernetes client or http client not configured")
	}

	// NOTE(review): worker pods are discovered by this fixed label selector —
	// confirm all targeted deployments label their workers this way.
	podList := &corev1.PodList{}
	labelSelector := client.MatchingLabels{"app.kubernetes.io/component": "worker"}
	if err := s.kubeClient.List(ctx, podList, client.InNamespace(s.podNamespace), labelSelector); err != nil {
		return 0, fmt.Errorf("failed to list worker pods in namespace %s: %w", s.podNamespace, err)
	}

	// No worker pods at all: nothing to scrape, zero used slots.
	if len(podList.Items) == 0 {
		return 0, nil
	}

	// Apply a timeout budget for the entire scrape loop.
	scrapeCtx, cancel := context.WithTimeout(ctx, scrapeLoopTimeout)
	defer cancel()

	var totalUsedSlots int64
	var scrapedCount, attemptedCount int
	for i := range podList.Items {
		pod := &podList.Items[i]
		// Skip pods that cannot serve metrics yet (not running, no IP, or not Ready).
		if pod.Status.Phase != corev1.PodRunning || pod.Status.PodIP == "" || !isPodReady(pod) {
			continue
		}

		// Stop scraping if we've exceeded the timeout budget.
		if scrapeCtx.Err() != nil {
			s.logger.Info("scrape loop timeout reached, using partial results",
				"scraped", scrapedCount, "remaining", len(podList.Items)-i)
			temporalSlotsScrapeErrors.WithLabelValues(s.podNamespace, s.metadata.TaskQueue, "scrape_loop_timeout").Inc()
			break
		}

		attemptedCount++
		slots, err := s.scrapeWorkerSlots(scrapeCtx, pod.Status.PodIP)
		if err != nil {
			// A single failing pod is skipped; the remaining pods still contribute.
			s.logger.Info("failed to scrape worker pod metrics, skipping",
				"pod", pod.Name, "ip", pod.Status.PodIP, "error", err)
			temporalSlotsScrapeErrors.WithLabelValues(s.podNamespace, s.metadata.TaskQueue, "pod_scrape_error").Inc()
			continue
		}
		totalUsedSlots += slots
		scrapedCount++
	}

	s.logger.V(1).Info("worker slots metric",
		"namespace", s.podNamespace, "totalUsedSlots", totalUsedSlots,
		"podCount", len(podList.Items), "scrapedCount", scrapedCount)

	// No ready pods to scrape (e.g. all pods still starting up) — return 0.
	if attemptedCount == 0 {
		return 0, nil
	}

	// All attempted scrapes failed — fall back to cached value within TTL.
	if scrapedCount == 0 {
		s.slotsMu.Lock()
		cached := s.lastSlots
		s.slotsMu.Unlock()
		if time.Since(cached.timestamp) <= slotsCacheTTL {
			s.logger.Info("all scrapes failed, using cached slots value",
				"cachedValue", cached.value, "cacheAge", time.Since(cached.timestamp).String())
			temporalSlotsScrapeErrors.WithLabelValues(s.podNamespace, s.metadata.TaskQueue, "all_pods_failed_cache_hit").Inc()
			return cached.value, nil
		}
		s.logger.Info("all scrapes failed and cache expired, returning 0")
		temporalSlotsScrapeErrors.WithLabelValues(s.podNamespace, s.metadata.TaskQueue, "all_pods_failed_cache_expired").Inc()
		return 0, nil
	}

	// Update cache with the fresh value.
	s.slotsMu.Lock()
	s.lastSlots = slotsCache{value: totalUsedSlots, timestamp: time.Now()}
	s.slotsMu.Unlock()

	return totalUsedSlots, nil
}
359+
360+ // scrapeWorkerSlots fetches Prometheus metrics from a single worker pod and returns
361+ // the sum of temporal_worker_task_slots_used for worker types matching the
362+ // configured queueTypes and task queue.
363+ func (s * temporalScaler ) scrapeWorkerSlots (ctx context.Context , podIP string ) (int64 , error ) {
364+ hostPort := net .JoinHostPort (podIP , strconv .Itoa (s .metadata .WorkerMetricsPort ))
365+ url := fmt .Sprintf ("http://%s/metrics" , hostPort )
366+ req , err := http .NewRequestWithContext (ctx , http .MethodGet , url , nil )
367+ if err != nil {
368+ return 0 , fmt .Errorf ("create request: %w" , err )
369+ }
370+
371+ resp , err := s .httpClient .Do (req )
372+ if err != nil {
373+ return 0 , fmt .Errorf ("scrape %s: %w" , url , err )
374+ }
375+ defer resp .Body .Close ()
376+
377+ if resp .StatusCode != http .StatusOK {
378+ return 0 , fmt .Errorf ("scrape %s returned status %d" , url , resp .StatusCode )
379+ }
380+
381+ limitedBody := io .LimitReader (resp .Body , maxMetricsResponseBytes )
382+ activityOnly := map [string ]bool {"ActivityWorker" : true }
383+ return parseUsedSlots (limitedBody , s .metadata .TaskQueue , activityOnly )
384+ }
385+
386+ // parseUsedSlots parses Prometheus text format and extracts the sum of
387+ // temporal_worker_task_slots_used for the given worker types matching the task queue.
388+ func parseUsedSlots (r io.Reader , taskQueue string , workerTypes map [string ]bool ) (int64 , error ) {
389+ var parser expfmt.TextParser
390+ families , err := parser .TextToMetricFamilies (r )
391+ if err != nil {
392+ return 0 , fmt .Errorf ("parse prometheus metrics: %w" , err )
393+ }
394+
395+ family , ok := families ["temporal_worker_task_slots_used" ]
396+ if ! ok {
397+ return 0 , nil
398+ }
399+
400+ var total int64
401+ for _ , m := range family .GetMetric () {
402+ if matchesWorkerSlot (m , taskQueue , workerTypes ) {
403+ total += int64 (m .GetGauge ().GetValue ())
404+ }
405+ }
406+ return total , nil
407+ }
408+
409+ // matchesWorkerSlot returns true if the metric's worker_type is in the allowed set
410+ // and (if taskQueue is non-empty) task_queue matches the configured queue.
411+ func matchesWorkerSlot (m * dto.Metric , taskQueue string , workerTypes map [string ]bool ) bool {
412+ var typeMatches bool
413+ queueMatches := taskQueue == ""
414+ for _ , lp := range m .GetLabel () {
415+ switch lp .GetName () {
416+ case "worker_type" :
417+ typeMatches = workerTypes [lp .GetValue ()]
418+ case "task_queue" :
419+ if ! queueMatches {
420+ queueMatches = lp .GetValue () == taskQueue
421+ }
422+ }
423+ }
424+ return typeMatches && queueMatches
425+ }
426+
427+ func isPodReady (pod * corev1.Pod ) bool {
428+ for _ , c := range pod .Status .Conditions {
429+ if c .Type == corev1 .PodReady {
430+ return c .Status == corev1 .ConditionTrue
431+ }
432+ }
433+ return false
434+ }
435+
204436func getQueueTypes (queueTypes []string ) []sdk.TaskQueueType {
205437 var taskQueueTypes []sdk.TaskQueueType
206438 for _ , t := range queueTypes {
0 commit comments