Skip to content

Commit 4721b2c

Browse files
authored
Merge pull request #3 from atlanhq/feature/temporal-scaler-worker-slots-metric
feat: Add task slots metric to temporal scaler
2 parents 0c957fa + 563cd72 commit 4721b2c

File tree

3 files changed

+513
-16
lines changed

3 files changed

+513
-16
lines changed

pkg/scalers/temporal_scaler.go

Lines changed: 247 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,88 @@ import (
44
"context"
55
"crypto/tls"
66
"fmt"
7+
"io"
78
"log/slog"
9+
"net"
10+
"net/http"
11+
"strconv"
812
"strings"
13+
"sync"
914
"time"
1015

1116
"github.com/go-logr/logr"
17+
"github.com/prometheus/client_golang/prometheus"
18+
dto "github.com/prometheus/client_model/go"
19+
"github.com/prometheus/common/expfmt"
1220
workflowservice "go.temporal.io/api/workflowservice/v1"
1321
sdk "go.temporal.io/sdk/client"
1422
sdklog "go.temporal.io/sdk/log"
1523
"google.golang.org/grpc"
1624
"google.golang.org/grpc/metadata"
1725
v2 "k8s.io/api/autoscaling/v2"
26+
corev1 "k8s.io/api/core/v1"
1827
"k8s.io/metrics/pkg/apis/external_metrics"
28+
"sigs.k8s.io/controller-runtime/pkg/client"
29+
ctrlmetrics "sigs.k8s.io/controller-runtime/pkg/metrics"
1930

2031
"github.com/kedacore/keda/v2/pkg/scalers/scalersconfig"
2132
kedautil "github.com/kedacore/keda/v2/pkg/util"
2233
)
2334

35+
const (
	// scrapeLoopTimeout is the total time budget for scraping all worker pods.
	// NOTE(review): this should stay comfortably below the KEDA polling
	// interval so a slow scrape cannot stall the metric cycle — confirm the
	// deployed pollingInterval exceeds 12s.
	scrapeLoopTimeout = 12 * time.Second
	// slotsCacheTTL is how long a cached slots value remains valid after a
	// successful scrape before falling back to 0 on persistent failure.
	slotsCacheTTL = 180 * time.Second
	// maxMetricsResponseBytes limits the size of a single pod's /metrics response
	// to prevent OOM from misconfigured or malicious pods.
	maxMetricsResponseBytes = 10 * 1024 * 1024 // 10 MB
)
45+
2446
var (
	// temporalDefauleQueueTypes is the default set of Temporal task-queue
	// types the scaler inspects when none are configured explicitly.
	// NOTE(review): identifier carries a typo ("Defaule" → "Default"); kept
	// as-is because it is referenced elsewhere in this file.
	temporalDefauleQueueTypes = []sdk.TaskQueueType{
		sdk.TaskQueueTypeActivity,
		sdk.TaskQueueTypeWorkflow,
		sdk.TaskQueueTypeNexus,
	}

	// temporalSlotsScrapeErrors counts worker slot scrape failures by reason:
	//   pod_scrape_error              – a single pod's /metrics request failed
	//   scrape_loop_timeout           – 12s budget exceeded, used partial results
	//   all_pods_failed_cache_hit     – all pods failed; returned last cached value
	//   all_pods_failed_cache_expired – all pods failed and cache expired; returned 0
	temporalSlotsScrapeErrors = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: "keda",
			Subsystem: "temporal_scaler",
			Name:      "worker_slots_scrape_errors_total",
			Help: "Total number of temporal worker slot scrape failures. " +
				"Use reason label to distinguish pod-level errors from full-scrape failures.",
		},
		// namespace = pod namespace of the ScaledObject, task_queue = the
		// configured Temporal task queue, reason = one of the values above.
		[]string{"namespace", "task_queue", "reason"},
	)
)
3169

70+
// init registers the scaler's scrape-error counter with the
// controller-runtime metrics registry so it is exported on the operator's
// /metrics endpoint. MustRegister panics on duplicate registration, which
// can only happen if this package is initialized twice.
func init() {
	ctrlmetrics.Registry.MustRegister(temporalSlotsScrapeErrors)
}
73+
74+
// slotsCache holds the last successfully scraped total of used worker task
// slots. It is consulted as a fallback when every pod scrape in a cycle
// fails, and is considered stale once older than slotsCacheTTL.
type slotsCache struct {
	value     int64     // last known good total of used slots
	timestamp time.Time // when value was recorded, compared via time.Since
}
78+
3279
// temporalScaler scales on Temporal task-queue backlog, optionally augmented
// with the running-workflow count and the number of worker task slots
// currently in use (scraped from worker pods' Prometheus endpoints).
// Contains a sync.Mutex, so instances must not be copied; all methods use
// pointer receivers.
type temporalScaler struct {
	metricType   v2.MetricTargetType
	metadata     *temporalMetadata
	tcl          sdk.Client     // Temporal SDK client for queue/workflow queries
	kubeClient   client.Client  // used to list worker pods for slot scraping
	httpClient   *http.Client   // used to GET each worker pod's /metrics
	logger       logr.Logger
	podNamespace string         // namespace of the ScaledObject; pods are listed here
	slotsMu      sync.Mutex     // guards lastSlots
	lastSlots    slotsCache     // fallback value when all scrapes fail
}
3890

3991
type temporalMetadata struct {
@@ -48,6 +100,7 @@ type temporalMetadata struct {
48100
Unversioned bool `keda:"name=selectUnversioned, order=triggerMetadata, default=false"`
49101
IncludeRunningWorkflowCount bool `keda:"name=includeRunningWorkflowCount, order=triggerMetadata, default=true"`
50102
WorkflowTaskQueueForCount string `keda:"name=workflowTaskQueueForCount, order=triggerMetadata;resolvedEnv, optional"`
103+
WorkerMetricsPort int `keda:"name=workerMetricsPort, order=triggerMetadata, default=9464"`
51104
APIKey string `keda:"name=apiKey, order=authParams;resolvedEnv, optional"`
52105
MinConnectTimeout int `keda:"name=minConnectTimeout, order=triggerMetadata, default=5"`
53106

@@ -77,10 +130,14 @@ func (a *temporalMetadata) Validate() error {
77130
return fmt.Errorf("minConnectTimeout must be a positive number")
78131
}
79132

133+
if a.WorkerMetricsPort < 1 || a.WorkerMetricsPort > 65535 {
134+
return fmt.Errorf("workerMetricsPort must be between 1 and 65535")
135+
}
136+
80137
return nil
81138
}
82139

83-
func NewTemporalScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (Scaler, error) {
140+
func NewTemporalScaler(ctx context.Context, kubeClient client.Client, config *scalersconfig.ScalerConfig) (Scaler, error) {
84141
logger := InitializeLogger(config, "temporal_scaler")
85142

86143
metricType, err := GetMetricTargetType(config)
@@ -99,10 +156,13 @@ func NewTemporalScaler(ctx context.Context, config *scalersconfig.ScalerConfig)
99156
}
100157

101158
return &temporalScaler{
102-
metricType: metricType,
103-
metadata: meta,
104-
tcl: c,
105-
logger: logger,
159+
metricType: metricType,
160+
metadata: meta,
161+
tcl: c,
162+
kubeClient: kubeClient,
163+
httpClient: kedautil.CreateHTTPClient(config.GlobalHTTPTimeout, false),
164+
logger: logger,
165+
podNamespace: config.ScalableObjectNamespace,
106166
}, nil
107167
}
108168

@@ -164,18 +224,25 @@ func (s *temporalScaler) getQueueSize(ctx context.Context) (int64, error) {
164224
}
165225

166226
backlog := getCombinedBacklogCount(resp)
227+
metric := backlog
167228

168-
if !s.metadata.IncludeRunningWorkflowCount {
169-
return backlog, nil
229+
if s.metadata.IncludeRunningWorkflowCount {
230+
runningCount, err := s.getRunningWorkflowCount(ctx)
231+
if err != nil {
232+
s.logger.V(1).Info("failed to get running workflow count, using backlog only", "error", err)
233+
} else {
234+
metric += runningCount
235+
}
170236
}
171237

172-
runningCount, err := s.getRunningWorkflowCount(ctx)
238+
usedSlots, err := s.getUsedWorkerSlots(ctx)
173239
if err != nil {
174-
s.logger.V(1).Info("failed to get running workflow count, using backlog only", "error", err)
175-
return backlog, nil
240+
s.logger.Info("failed to get worker slots metric, excluding from metric", "error", err)
241+
} else {
242+
metric += usedSlots
176243
}
177244

178-
return backlog + runningCount, nil
245+
return metric, nil
179246
}
180247

181248
// getRunningWorkflowCount returns the approximate number of running workflow executions
@@ -201,6 +268,171 @@ func (s *temporalScaler) getRunningWorkflowCount(ctx context.Context) (int64, er
201268
return resp.GetCount(), nil
202269
}
203270

271+
// getUsedWorkerSlots discovers worker pods in the ScaledObject's namespace and
// scrapes their Prometheus metrics endpoint to sum temporal_worker_task_slots_used
// for worker types matching the configured queueTypes. This prevents premature
// scale-down when workers are actively executing tasks but the task queue backlog
// is empty.
//
// On transient failures (all pod scrapes fail), it returns the last known good
// value if within the cache TTL. A total timeout budget bounds the scrape loop
// so that slow/unreachable pods don't block the KEDA polling cycle.
//
// Pods are selected by the label app.kubernetes.io/component=worker; only
// Running, IP-assigned, Ready pods are attempted. Partial results (timeout
// mid-loop) are still returned and cached.
func (s *temporalScaler) getUsedWorkerSlots(ctx context.Context) (int64, error) {
	// Defensive guard: both clients are wired in NewTemporalScaler, but the
	// feature degrades to an error (caller logs and excludes the metric)
	// rather than panicking if either is missing.
	if s.kubeClient == nil || s.httpClient == nil {
		return 0, fmt.Errorf("kubernetes client or http client not configured")
	}

	podList := &corev1.PodList{}
	labelSelector := client.MatchingLabels{"app.kubernetes.io/component": "worker"}
	if err := s.kubeClient.List(ctx, podList, client.InNamespace(s.podNamespace), labelSelector); err != nil {
		return 0, fmt.Errorf("failed to list worker pods in namespace %s: %w", s.podNamespace, err)
	}

	// No workers deployed at all — nothing in use.
	if len(podList.Items) == 0 {
		return 0, nil
	}

	// Apply a timeout budget for the entire scrape loop.
	scrapeCtx, cancel := context.WithTimeout(ctx, scrapeLoopTimeout)
	defer cancel()

	var totalUsedSlots int64
	// attemptedCount counts pods we actually tried to scrape (ready pods);
	// scrapedCount counts successes. The distinction drives the fallback
	// logic below: no attempts != all attempts failed.
	var scrapedCount, attemptedCount int
	for i := range podList.Items {
		pod := &podList.Items[i]
		// Skip pods that cannot be scraped yet (starting up, no IP, not Ready).
		if pod.Status.Phase != corev1.PodRunning || pod.Status.PodIP == "" || !isPodReady(pod) {
			continue
		}

		// Stop scraping if we've exceeded the timeout budget.
		if scrapeCtx.Err() != nil {
			s.logger.Info("scrape loop timeout reached, using partial results",
				"scraped", scrapedCount, "remaining", len(podList.Items)-i)
			temporalSlotsScrapeErrors.WithLabelValues(s.podNamespace, s.metadata.TaskQueue, "scrape_loop_timeout").Inc()
			break
		}

		attemptedCount++
		slots, err := s.scrapeWorkerSlots(scrapeCtx, pod.Status.PodIP)
		if err != nil {
			// A single pod failure is non-fatal: count it and move on.
			s.logger.Info("failed to scrape worker pod metrics, skipping",
				"pod", pod.Name, "ip", pod.Status.PodIP, "error", err)
			temporalSlotsScrapeErrors.WithLabelValues(s.podNamespace, s.metadata.TaskQueue, "pod_scrape_error").Inc()
			continue
		}
		totalUsedSlots += slots
		scrapedCount++
	}

	s.logger.V(1).Info("worker slots metric",
		"namespace", s.podNamespace, "totalUsedSlots", totalUsedSlots,
		"podCount", len(podList.Items), "scrapedCount", scrapedCount)

	// No ready pods to scrape (e.g. all pods still starting up) — return 0.
	if attemptedCount == 0 {
		return 0, nil
	}

	// All attempted scrapes failed — fall back to cached value within TTL.
	if scrapedCount == 0 {
		// Copy under lock; the cache may be written by a concurrent poll.
		s.slotsMu.Lock()
		cached := s.lastSlots
		s.slotsMu.Unlock()
		if time.Since(cached.timestamp) <= slotsCacheTTL {
			s.logger.Info("all scrapes failed, using cached slots value",
				"cachedValue", cached.value, "cacheAge", time.Since(cached.timestamp).String())
			temporalSlotsScrapeErrors.WithLabelValues(s.podNamespace, s.metadata.TaskQueue, "all_pods_failed_cache_hit").Inc()
			return cached.value, nil
		}
		// Cache expired: report 0 rather than a stale value so persistent
		// outages eventually allow scale-down.
		s.logger.Info("all scrapes failed and cache expired, returning 0")
		temporalSlotsScrapeErrors.WithLabelValues(s.podNamespace, s.metadata.TaskQueue, "all_pods_failed_cache_expired").Inc()
		return 0, nil
	}

	// Update cache with the fresh value. Note a timeout mid-loop still caches
	// the partial sum, which is the best available estimate.
	s.slotsMu.Lock()
	s.lastSlots = slotsCache{value: totalUsedSlots, timestamp: time.Now()}
	s.slotsMu.Unlock()

	return totalUsedSlots, nil
}
359+
360+
// scrapeWorkerSlots fetches Prometheus metrics from a single worker pod over
// plain HTTP and returns the sum of temporal_worker_task_slots_used for the
// configured task queue.
//
// NOTE(review): only the "ActivityWorker" worker type is counted (hardcoded
// below), regardless of the scaler's configured queueTypes — the previous
// comment claiming queueTypes were matched was inaccurate. Confirm whether
// WorkflowWorker/NexusWorker slots should also be included.
//
// The response body is capped at maxMetricsResponseBytes to bound memory use.
func (s *temporalScaler) scrapeWorkerSlots(ctx context.Context, podIP string) (int64, error) {
	// JoinHostPort handles IPv6 pod IPs correctly (brackets).
	hostPort := net.JoinHostPort(podIP, strconv.Itoa(s.metadata.WorkerMetricsPort))
	url := fmt.Sprintf("http://%s/metrics", hostPort)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return 0, fmt.Errorf("create request: %w", err)
	}

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return 0, fmt.Errorf("scrape %s: %w", url, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return 0, fmt.Errorf("scrape %s returned status %d", url, resp.StatusCode)
	}

	// Cap the body size so a misbehaving pod cannot OOM the operator.
	limitedBody := io.LimitReader(resp.Body, maxMetricsResponseBytes)
	activityOnly := map[string]bool{"ActivityWorker": true}
	return parseUsedSlots(limitedBody, s.metadata.TaskQueue, activityOnly)
}
385+
386+
// parseUsedSlots parses Prometheus text format and extracts the sum of
387+
// temporal_worker_task_slots_used for the given worker types matching the task queue.
388+
func parseUsedSlots(r io.Reader, taskQueue string, workerTypes map[string]bool) (int64, error) {
389+
var parser expfmt.TextParser
390+
families, err := parser.TextToMetricFamilies(r)
391+
if err != nil {
392+
return 0, fmt.Errorf("parse prometheus metrics: %w", err)
393+
}
394+
395+
family, ok := families["temporal_worker_task_slots_used"]
396+
if !ok {
397+
return 0, nil
398+
}
399+
400+
var total int64
401+
for _, m := range family.GetMetric() {
402+
if matchesWorkerSlot(m, taskQueue, workerTypes) {
403+
total += int64(m.GetGauge().GetValue())
404+
}
405+
}
406+
return total, nil
407+
}
408+
409+
// matchesWorkerSlot returns true if the metric's worker_type is in the allowed set
410+
// and (if taskQueue is non-empty) task_queue matches the configured queue.
411+
func matchesWorkerSlot(m *dto.Metric, taskQueue string, workerTypes map[string]bool) bool {
412+
var typeMatches bool
413+
queueMatches := taskQueue == ""
414+
for _, lp := range m.GetLabel() {
415+
switch lp.GetName() {
416+
case "worker_type":
417+
typeMatches = workerTypes[lp.GetValue()]
418+
case "task_queue":
419+
if !queueMatches {
420+
queueMatches = lp.GetValue() == taskQueue
421+
}
422+
}
423+
}
424+
return typeMatches && queueMatches
425+
}
426+
427+
func isPodReady(pod *corev1.Pod) bool {
428+
for _, c := range pod.Status.Conditions {
429+
if c.Type == corev1.PodReady {
430+
return c.Status == corev1.ConditionTrue
431+
}
432+
}
433+
return false
434+
}
435+
204436
func getQueueTypes(queueTypes []string) []sdk.TaskQueueType {
205437
var taskQueueTypes []sdk.TaskQueueType
206438
for _, t := range queueTypes {

0 commit comments

Comments
 (0)