@@ -10,15 +10,18 @@ import (
1010 neturl "net/url"
1111 "strconv"
1212 "strings"
13+ "sync"
1314
1415 "github.com/go-logr/logr"
1516 "github.com/prometheus/common/expfmt"
1617 "github.com/prometheus/prometheus/promql/parser"
1718 "github.com/tidwall/gjson"
1819 "gopkg.in/yaml.v3"
1920 v2 "k8s.io/api/autoscaling/v2"
21+ corev1 "k8s.io/api/core/v1"
2022 "k8s.io/apimachinery/pkg/api/resource"
2123 "k8s.io/metrics/pkg/apis/external_metrics"
24+ "sigs.k8s.io/controller-runtime/pkg/client"
2225
2326 "github.com/kedacore/keda/v2/pkg/scalers/authentication"
2427 "github.com/kedacore/keda/v2/pkg/scalers/scalersconfig"
@@ -30,15 +33,18 @@ type metricsAPIScaler struct {
3033 metadata * metricsAPIScalerMetadata
3134 httpClient * http.Client
3235 logger logr.Logger
36+ kubeClient client.Client
3337}
3438
3539type metricsAPIScalerMetadata struct {
36- targetValue float64
37- activationTargetValue float64
38- url string
39- format APIFormat
40- valueLocation string
41- unsafeSsl bool
40+ targetValue float64
41+ activationTargetValue float64
42+ url string
43+ format APIFormat
44+ aggregationType AggregationType
45+ valueLocation string
46+ unsafeSsl bool
47+ aggregateFromKubeServiceEndpoints bool
4248
4349 // apiKeyAuth
4450 enableAPIKeyAuth bool
@@ -71,6 +77,8 @@ const (
7177 valueLocationWrongErrorMsg = "valueLocation must point to value of type number or a string representing a Quantity got: '%s'"
7278)
7379
80+ const secureHTTPScheme = "https"
81+
7482type APIFormat string
7583
7684// Options for APIFormat:
9098 }
9199)
92100
101+ type AggregationType string
102+
103+ // Options for APIFormat:
104+ const (
105+ AverageAggregationType AggregationType = "average"
106+ SumAggregationType AggregationType = "sum"
107+ MaxAggregationType AggregationType = "max"
108+ MinAggregationType AggregationType = "min"
109+ )
110+
111+ var (
112+ supportedAggregationTypes = []AggregationType {
113+ AverageAggregationType ,
114+ SumAggregationType ,
115+ MaxAggregationType ,
116+ MinAggregationType ,
117+ }
118+ )
119+
93120// NewMetricsAPIScaler creates a new HTTP scaler
94- func NewMetricsAPIScaler (config * scalersconfig.ScalerConfig ) (Scaler , error ) {
121+ func NewMetricsAPIScaler (config * scalersconfig.ScalerConfig , kubeClient client. Client ) (Scaler , error ) {
95122 metricType , err := GetMetricTargetType (config )
96123 if err != nil {
97124 return nil , fmt .Errorf ("error getting scaler metric type: %w" , err )
@@ -116,6 +143,7 @@ func NewMetricsAPIScaler(config *scalersconfig.ScalerConfig) (Scaler, error) {
116143 metricType : metricType ,
117144 metadata : meta ,
118145 httpClient : httpClient ,
146+ kubeClient : kubeClient ,
119147 logger : InitializeLogger (config , "metrics_api_scaler" ),
120148 }, nil
121149}
@@ -133,6 +161,15 @@ func parseMetricsAPIMetadata(config *scalersconfig.ScalerConfig) (*metricsAPISca
133161 meta .unsafeSsl = unsafeSsl
134162 }
135163
164+ meta .aggregateFromKubeServiceEndpoints = false
165+ if val , ok := config .TriggerMetadata ["aggregateFromKubeServiceEndpoints" ]; ok {
166+ aggregateFromKubeServiceEndpoints , err := strconv .ParseBool (val )
167+ if err != nil {
168+ return nil , fmt .Errorf ("error parsing aggregateFromKubeServiceEndpoints: %w" , err )
169+ }
170+ meta .aggregateFromKubeServiceEndpoints = aggregateFromKubeServiceEndpoints
171+ }
172+
136173 if val , ok := config .TriggerMetadata ["targetValue" ]; ok {
137174 targetValue , err := strconv .ParseFloat (val , 64 )
138175 if err != nil {
@@ -172,6 +209,16 @@ func parseMetricsAPIMetadata(config *scalersconfig.ScalerConfig) (*metricsAPISca
172209 meta .format = JSONFormat
173210 }
174211
212+ if val , ok := config .TriggerMetadata ["aggregationType" ]; ok {
213+ meta .aggregationType = AggregationType (strings .TrimSpace (val ))
214+ if ! kedautil .Contains (supportedAggregationTypes , meta .aggregationType ) {
215+ return nil , fmt .Errorf ("aggregation type %s not supported" , meta .aggregationType )
216+ }
217+ } else {
218+ // default aggregation type is average
219+ meta .aggregationType = AverageAggregationType
220+ }
221+
175222 if val , ok := config .TriggerMetadata ["valueLocation" ]; ok {
176223 meta .valueLocation = val
177224 } else {
@@ -408,8 +455,147 @@ func getValueFromYAMLResponse(body []byte, valueLocation string) (float64, error
408455 }
409456}
410457
458+ func (s * metricsAPIScaler ) getEndpointsUrlsFromServiceURL (ctx context.Context , serviceURL string ) (endpointUrls []string , err error ) {
459+ // parse service name from s.meta.url
460+ url , err := neturl .Parse (serviceURL )
461+ if err != nil {
462+ s .logger .Error (err , "Failed parsing url for metrics API" )
463+ } else {
464+ splittedHost := strings .Split (url .Host , "." )
465+ if len (splittedHost ) < 2 {
466+ return nil , fmt .Errorf ("invalid hostname %s : expected at least 2 elements, first being service name and second being the namespace" , url .Host )
467+ }
468+ serviceName := splittedHost [0 ]
469+ namespace := splittedHost [1 ]
470+ podPort := url .Port ()
471+ // infer port from service scheme when not set explicitly
472+ if podPort == "" {
473+ if url .Scheme == secureHTTPScheme {
474+ podPort = "443"
475+ } else {
476+ podPort = "80"
477+ }
478+ }
479+ // get service serviceEndpoints
480+ serviceEndpoints := & corev1.Endpoints {}
481+
482+ err := s .kubeClient .Get (ctx , client.ObjectKey {
483+ Namespace : namespace ,
484+ Name : serviceName ,
485+ }, serviceEndpoints )
486+ if err != nil {
487+ return nil , err
488+ }
489+
490+ for _ , subset := range serviceEndpoints .Subsets {
491+ foundPort := ""
492+ for _ , port := range subset .Ports {
493+ if strconv .Itoa (int (port .Port )) == podPort {
494+ foundPort = fmt .Sprintf (":%d" , port .Port )
495+ break
496+ }
497+ }
498+ if foundPort == "" {
499+ s .logger .Info (fmt .Sprintf ("Warning : could not find port %s in endpoint slice for service %s.%s definition. Will infer port from %s scheme" , podPort , serviceName , namespace , url .Scheme ))
500+ }
501+ for _ , address := range subset .Addresses {
502+ if address .NodeName != nil {
503+ endpointUrls = append (endpointUrls , fmt .Sprintf ("%s://%s%s/%s" , url .Scheme , address .IP , foundPort , url .Path ))
504+ }
505+ }
506+ }
507+ }
508+ return endpointUrls , err
509+ }
510+
411511func (s * metricsAPIScaler ) getMetricValue (ctx context.Context ) (float64 , error ) {
412- request , err := getMetricAPIServerRequest (ctx , s .metadata )
512+ // if we wish to aggregate metric from a kubernetes service then we need to query each endpoint behind the service
513+ if s .metadata .aggregateFromKubeServiceEndpoints {
514+ endpointsUrls , err := s .getEndpointsUrlsFromServiceURL (ctx , s .metadata .url )
515+ if err != nil {
516+ s .logger .Error (err , "Failed to get kubernetes endpoints urls from configured service URL. Falling back to querying url configured in metadata" )
517+ } else {
518+ if len (endpointsUrls ) == 0 {
519+ s .logger .Error (err , "No endpoints URLs were given for the service name. Falling back to querying url configured in metadata" )
520+ } else {
521+ aggregatedMetric , err := s .aggregateMetricsFromMultipleEndpoints (ctx , endpointsUrls )
522+ if err != nil {
523+ s .logger .Error (err , "No aggregated metrics could be computed from service endpoints. Falling back to querying url configured in metadata" )
524+ } else {
525+ return aggregatedMetric , err
526+ }
527+ }
528+ }
529+ }
530+ // get single/unaggregated metric
531+ metric , err := s .getMetricValueFromURL (ctx , nil )
532+ if err == nil {
533+ s .logger .V (1 ).Info (fmt .Sprintf ("fetched single metric from metrics API url : %s. Value is %v\n " , s .metadata .url , metric ))
534+ }
535+ return metric , err
536+ }
537+
538+ func (s * metricsAPIScaler ) aggregateMetricsFromMultipleEndpoints (ctx context.Context , endpointsUrls []string ) (float64 , error ) {
539+ // call s.getMetricValueFromURL() for each endpointsUrls in parallel goroutines (maximum 5 at a time) and sum them up
540+ const maxGoroutines = 5
541+ var mu sync.Mutex
542+ var wg sync.WaitGroup
543+ sem := make (chan struct {}, maxGoroutines )
544+ expectedNbMetrics := len (endpointsUrls )
545+ nbErrors := 0
546+ var err error
547+ var firstMetricEncountered bool
548+ var aggregation float64
549+ for _ , endpointURL := range endpointsUrls {
550+ wg .Add (1 )
551+ sem <- struct {}{} // Acquire semaphore slot
552+ go func (url string ) {
553+ defer wg .Done ()
554+ metric , err := s .getMetricValueFromURL (ctx , & endpointURL )
555+
556+ if err != nil {
557+ s .logger .Info (fmt .Sprintf ("Error fetching metric for %s: %v\n " , url , err ))
558+ // we will ignore metric for computing aggregation when encountering error : decrease expectedNbMetrics
559+ mu .Lock ()
560+ expectedNbMetrics --
561+ nbErrors ++
562+ mu .Unlock ()
563+ } else {
564+ mu .Lock ()
565+ switch s .metadata .aggregationType {
566+ case MinAggregationType :
567+ if ! firstMetricEncountered || metric < aggregation {
568+ firstMetricEncountered = true
569+ aggregation = metric
570+ }
571+ case MaxAggregationType :
572+ if ! firstMetricEncountered || metric > aggregation {
573+ firstMetricEncountered = true
574+ aggregation = metric
575+ }
576+ default :
577+ // sum metrics if we are not looking for min or max value
578+ aggregation += metric
579+ }
580+ mu .Unlock ()
581+ }
582+ <- sem // Release semaphore slot
583+ }(endpointURL )
584+ }
585+
586+ wg .Wait ()
587+ if nbErrors > 0 && nbErrors == len (endpointsUrls ) {
588+ err = fmt .Errorf ("could not get any metric successfully from the %d provided endpoints" , len (endpointsUrls ))
589+ }
590+ if s .metadata .aggregationType == AverageAggregationType {
591+ aggregation /= float64 (expectedNbMetrics )
592+ }
593+ s .logger .V (1 ).Info (fmt .Sprintf ("fetched %d metrics out of %d endpoints from kubernetes service : %s is %v\n " , expectedNbMetrics , len (endpointsUrls ), s .metadata .aggregationType , aggregation ))
594+ return aggregation , err
595+ }
596+
597+ func (s * metricsAPIScaler ) getMetricValueFromURL (ctx context.Context , url * string ) (float64 , error ) {
598+ request , err := getMetricAPIServerRequest (ctx , s .metadata , url )
413599 if err != nil {
414600 return 0 , err
415601 }
@@ -470,14 +656,17 @@ func (s *metricsAPIScaler) GetMetricsAndActivity(ctx context.Context, metricName
470656 return []external_metrics.ExternalMetricValue {metric }, val > s .metadata .activationTargetValue , nil
471657}
472658
473- func getMetricAPIServerRequest (ctx context.Context , meta * metricsAPIScalerMetadata ) (* http.Request , error ) {
659+ func getMetricAPIServerRequest (ctx context.Context , meta * metricsAPIScalerMetadata , url * string ) (* http.Request , error ) {
474660 var req * http.Request
475661 var err error
476662
663+ if url == nil {
664+ url = & meta .url
665+ }
477666 switch {
478667 case meta .enableAPIKeyAuth :
479668 if meta .method == methodValueQuery {
480- url , _ := neturl .Parse (meta . url )
669+ url , _ := neturl .Parse (* url )
481670 queryString := url .Query ()
482671 if len (meta .keyParamName ) == 0 {
483672 queryString .Set ("api_key" , meta .apiKey )
@@ -492,7 +681,7 @@ func getMetricAPIServerRequest(ctx context.Context, meta *metricsAPIScalerMetada
492681 }
493682 } else {
494683 // default behaviour is to use header method
495- req , err = http .NewRequestWithContext (ctx , "GET" , meta . url , nil )
684+ req , err = http .NewRequestWithContext (ctx , "GET" , * url , nil )
496685 if err != nil {
497686 return nil , err
498687 }
@@ -504,20 +693,20 @@ func getMetricAPIServerRequest(ctx context.Context, meta *metricsAPIScalerMetada
504693 }
505694 }
506695 case meta .enableBaseAuth :
507- req , err = http .NewRequestWithContext (ctx , "GET" , meta . url , nil )
696+ req , err = http .NewRequestWithContext (ctx , "GET" , * url , nil )
508697 if err != nil {
509698 return nil , err
510699 }
511700
512701 req .SetBasicAuth (meta .username , meta .password )
513702 case meta .enableBearerAuth :
514- req , err = http .NewRequestWithContext (ctx , "GET" , meta . url , nil )
703+ req , err = http .NewRequestWithContext (ctx , "GET" , * url , nil )
515704 if err != nil {
516705 return nil , err
517706 }
518707 req .Header .Add ("Authorization" , fmt .Sprintf ("Bearer %s" , meta .bearerToken ))
519708 default :
520- req , err = http .NewRequestWithContext (ctx , "GET" , meta . url , nil )
709+ req , err = http .NewRequestWithContext (ctx , "GET" , * url , nil )
521710 if err != nil {
522711 return nil , err
523712 }
0 commit comments