diff --git a/docs/architecture.md b/docs/architecture.md index e9bf3cdc7..423c7ace1 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -355,6 +355,21 @@ Pods with requests in the queue will get score between 0.5 and 0. --- +#### ActiveRequestScorer + +Scores pods based on the number of active requests being served per pod. Each request is tracked +individually with its own TTL to ensure accurate timeout handling. Pods with fewer active +requests receive higher scores. + +Scores are normalized to a range of 0-1, where pods with fewer active requests get higher scores. + +- **Type**: `active-request-scorer` +- **Parameters**: + - `requestTimeout`: specifies the timeout for requests as a duration string (e.g. "30s", "1m", "2h"). Once a request is "in-flight" + for this duration, it is considered timed out and automatically removed. + +--- + #### SessionAffinity + Scores the candidate pods by giving a higher score to the pods that were previously diff --git a/go.mod b/go.mod index c4764b6d4..f9346560d 100644 --- a/go.mod +++ b/go.mod @@ -51,6 +51,7 @@ require ( github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jellydator/ttlcache/v3 v3.4.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/mailru/easyjson v0.7.7 // indirect diff --git a/go.sum b/go.sum index 8e2fc5b2c..402a925f8 100644 --- a/go.sum +++ b/go.sum @@ -88,6 +88,8 @@ github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jellydator/ttlcache/v3 v3.4.0 h1:YS4P125qQS0tNhtL6aeYkheEaB/m8HCqdMMP4mnWdTY=
+github.com/jellydator/ttlcache/v3 v3.4.0/go.mod h1:Hw9EgjymziQD3yGsQdf1FqFdpp7YjFMd4Srg5EJlgD4= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= diff --git a/pkg/plugins/register.go b/pkg/plugins/register.go index 483634c08..bb635726f 100644 --- a/pkg/plugins/register.go +++ b/pkg/plugins/register.go @@ -21,4 +21,5 @@ func RegisterAllPlugins() { plugins.Register(prefix.PrefixCachePluginType, scorer.PrefixCachePluginFactory) plugins.Register(scorer.LoadAwareScorerType, scorer.LoadAwareScorerFactory) plugins.Register(scorer.SessionAffinityScorerType, scorer.SessionAffinityScorerFactory) + plugins.Register(scorer.ActiveRequestScorerType, scorer.ActiveRequestScorerFactory) } diff --git a/pkg/plugins/scorer/active_request.go b/pkg/plugins/scorer/active_request.go new file mode 100644 index 000000000..11b2a9ade --- /dev/null +++ b/pkg/plugins/scorer/active_request.go @@ -0,0 +1,246 @@ +package scorer + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/jellydator/ttlcache/v3" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +const ( + // ActiveRequestScorerType is the type of the ActiveRequestScorer + ActiveRequestScorerType = "active-request-scorer" + + // defaultRequestTimeout defines the default timeout for open requests to be + // considered stale and removed from the cache. 
+ defaultRequestTimeout = 2 * time.Minute +) + +// ActiveRequestScorerParameters defines the parameters for the +// ActiveRequestScorer. +type ActiveRequestScorerParameters struct { + // RequestTimeout defines the timeout for requests as a duration string. + // Once the request is "in-flight" for this duration, it is considered to + // be timed out and dropped. + // This field accepts duration strings like "30s", "1m", "2h". + RequestTimeout string `json:"requestTimeout"` +} + +// requestEntry represents a single request in the cache +type requestEntry struct { + PodName string + RequestID string +} + +// String returns a string representation of the request entry. +func (r *requestEntry) String() string { + return fmt.Sprintf("%s.%s", r.PodName, r.RequestID) +} + +// compile-time type assertion +var _ framework.Scorer = &ActiveRequestScorer{} + +// ActiveRequestScorerFactory defines the factory function for the ActiveRequestScorer. +func ActiveRequestScorerFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { + parameters := ActiveRequestScorerParameters{} + if rawParameters != nil { + if err := json.Unmarshal(rawParameters, &parameters); err != nil { + return nil, fmt.Errorf("failed to parse the parameters of the '%s' scorer - %w", ActiveRequestScorerType, err) + } + } + + return NewActiveRequestScorer(handle.Context(), &parameters).WithName(name), nil +} + +// NewActiveRequestScorer creates a new ActiveRequestScorer scorer.
+func NewActiveRequestScorer(ctx context.Context, params *ActiveRequestScorerParameters) *ActiveRequestScorer { + requestTimeout := defaultRequestTimeout + logger := log.FromContext(ctx) + + if params != nil && params.RequestTimeout != "" { + paramsRequestTimeout, err := time.ParseDuration(params.RequestTimeout) + if err != nil || paramsRequestTimeout <= 0 { + logger.Error(err, "Invalid request timeout duration, using default request timeout") + } else { + requestTimeout = paramsRequestTimeout + logger.Info("Using request timeout", "requestTimeout", requestTimeout) + } + } + + // cache for individual requests with their own TTL + requestCache := ttlcache.New[string, *requestEntry]( + ttlcache.WithTTL[string, *requestEntry](requestTimeout), + ttlcache.WithDisableTouchOnHit[string, *requestEntry](), + ) + + scorer := &ActiveRequestScorer{ + typedName: plugins.TypedName{Type: ActiveRequestScorerType}, + requestCache: requestCache, + podCounts: make(map[string]int), + mutex: &sync.RWMutex{}, + } + // callback to decrement count when requests expire + // most requests will be removed in PostResponse, but this ensures + // that we don't leak pod counts if PostResponse is not called + requestCache.OnEviction(func(_ context.Context, reason ttlcache.EvictionReason, + item *ttlcache.Item[string, *requestEntry]) { + if reason == ttlcache.EvictionReasonExpired { + scorer.decrementPodCount(item.Value().PodName) + } + }) + + go cleanCachePeriodically(ctx, requestCache, requestTimeout) + + return scorer +} + +// ActiveRequestScorer keeps track of individual requests being served +// per pod. +type ActiveRequestScorer struct { + typedName plugins.TypedName + + // requestCache stores individual request entries with unique composite keys (podName.requestID) + requestCache *ttlcache.Cache[string, *requestEntry] + + // podCounts maintains fast lookup for request counts per pod + podCounts map[string]int + mutex *sync.RWMutex +} + +// TypedName returns the typed name of the plugin. 
+func (s *ActiveRequestScorer) TypedName() plugins.TypedName { + return s.typedName +} + +// WithName sets the name of the plugin. +func (s *ActiveRequestScorer) WithName(name string) *ActiveRequestScorer { + s.typedName.Name = name + return s +} + +// Score scores the given pods based on the number of active requests +// being served by each pod. The score is normalized to a range of 0-1. +func (s *ActiveRequestScorer) Score(ctx context.Context, _ *types.CycleState, _ *types.LLMRequest, + pods []types.Pod) map[types.Pod]float64 { + scoredPods := make(map[string]int) + maxCount := 0 + s.mutex.RLock() + for podName, count := range s.podCounts { + scoredPods[podName] = count + if count >= maxCount { + maxCount = count + } + } + s.mutex.RUnlock() + + scoredPodsMap := make(map[types.Pod]float64, len(pods)) + for _, pod := range pods { + podName := pod.GetPod().NamespacedName.String() + if count, exists := scoredPods[podName]; exists { + if count == 0 { + scoredPodsMap[pod] = 1.0 // no requests means highest score + } else { + scoredPodsMap[pod] = float64(maxCount-count) / float64(maxCount) + } + } else { + scoredPodsMap[pod] = 1.0 + } + } + + log.FromContext(ctx).V(logutil.DEBUG).Info("Scored pods", "scores", scoredPodsMap) + return scoredPodsMap +} + +// PreRequest is called before a request is sent to the target pod. +// It creates a new request entry in the cache with its own TTL and +// increments the pod count for fast lookup. +func (s *ActiveRequestScorer) PreRequest(ctx context.Context, request *types.LLMRequest, + schedulingResult *types.SchedulingResult, _ int) { + debugLogger := log.FromContext(ctx).V(logutil.DEBUG) + + for _, profileResult := range schedulingResult.ProfileResults { // schedulingResult guaranteed not to be nil + if profileResult == nil || profileResult.TargetPods == nil || len(profileResult.TargetPods) == 0 { + continue + } + + // create request entry for first pod only. 
TODO: support fallback pods + entry := &requestEntry{ + PodName: profileResult.TargetPods[0].GetPod().NamespacedName.String(), + RequestID: request.RequestId, + } + + // add to request cache with TTL + s.requestCache.Set(entry.String(), entry, 0) // Use default TTL + s.incrementPodCount(entry.PodName) + + debugLogger.Info("Added request to cache", "requestEntry", entry.String()) + } +} + +// PostResponse is called after a response is sent to the client. +// It removes the specific request entry from the cache and decrements +// the pod count. +func (s *ActiveRequestScorer) PostResponse(ctx context.Context, request *types.LLMRequest, + _ *requestcontrol.Response, targetPod *backend.Pod) { + debugLogger := log.FromContext(ctx).V(logutil.DEBUG).WithName("ActiveRequestScorer.PostResponse") + if targetPod == nil { + debugLogger.Info("Skipping PostResponse because targetPod is nil") + return + } + + entry := requestEntry{targetPod.NamespacedName.String(), request.RequestId} + + if _, found := s.requestCache.GetAndDelete(entry.String()); found { + s.decrementPodCount(entry.PodName) + debugLogger.Info("Removed request from cache", "requestEntry", entry.String()) + } else { + debugLogger.Info("Request not found in cache", "requestEntry", entry.String()) + } +} + +// incrementPodCount increments the request count for a pod. +func (s *ActiveRequestScorer) incrementPodCount(podName string) { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.podCounts[podName]++ +} + +// decrementPodCount decrements the request count for a pod and removes +// the entry if count reaches zero. 
+func (s *ActiveRequestScorer) decrementPodCount(podName string) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if count, exists := s.podCounts[podName]; exists { + if count <= 1 { + delete(s.podCounts, podName) + } else { + s.podCounts[podName] = count - 1 + } + } +} + +func cleanCachePeriodically(ctx context.Context, cache *ttlcache.Cache[string, *requestEntry], requestTimeout time.Duration) { + ticker := time.NewTicker(requestTimeout) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + cache.DeleteExpired() + } + } +} diff --git a/pkg/plugins/scorer/active_request_test.go b/pkg/plugins/scorer/active_request_test.go new file mode 100644 index 000000000..25f8f3c7e --- /dev/null +++ b/pkg/plugins/scorer/active_request_test.go @@ -0,0 +1,304 @@ +package scorer + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + k8stypes "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +func TestActiveRequestScorer_Score(t *testing.T) { + podA := &types.PodMetrics{ + Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}}, + MetricsState: &backendmetrics.MetricsState{ + WaitingQueueSize: 2, + }, + } + podB := &types.PodMetrics{ + Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b", Namespace: "default"}}, + MetricsState: &backendmetrics.MetricsState{ + WaitingQueueSize: 0, + }, + } + podC := &types.PodMetrics{ + Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-c", Namespace: "default"}}, + MetricsState: &backendmetrics.MetricsState{ + WaitingQueueSize: 15, + }, + } + + tests := []struct { + name string + setupCache func(*ActiveRequestScorer) + input 
[]types.Pod + wantScores map[types.Pod]float64 + }{ + { + name: "no pods in cache", + setupCache: func(_ *ActiveRequestScorer) { + // Cache is empty + }, + input: []types.Pod{podA, podB, podC}, + wantScores: map[types.Pod]float64{ + podA: 1, + podB: 1, + podC: 1, + }, + }, + { + name: "all pods in cache with different request counts", + setupCache: func(s *ActiveRequestScorer) { + s.mutex.Lock() + s.podCounts["default/pod-a"] = 3 + s.podCounts["default/pod-b"] = 0 + s.podCounts["default/pod-c"] = 6 + s.mutex.Unlock() + }, + input: []types.Pod{podA, podB, podC}, + wantScores: map[types.Pod]float64{ + podA: 0.5, + podB: 1.0, + podC: 0.0, + }, + }, + { + name: "some pods in cache", + setupCache: func(s *ActiveRequestScorer) { + s.mutex.Lock() + s.podCounts["default/pod-a"] = 4 + s.podCounts["default/pod-c"] = 1 + // pod-b not in cache + s.mutex.Unlock() + }, + input: []types.Pod{podA, podB, podC}, + wantScores: map[types.Pod]float64{ + podA: 0.0, + podB: 1.0, + podC: 0.75, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + scorer := NewActiveRequestScorer(context.Background(), nil) + test.setupCache(scorer) + + got := scorer.Score(context.Background(), nil, nil, test.input) + + if diff := cmp.Diff(test.wantScores, got); diff != "" { + t.Errorf("Unexpected output (-want +got): %v", diff) + } + }) + } +} + +func TestActiveRequestScorer_PreRequest(t *testing.T) { + ctx := context.Background() + + scorer := NewActiveRequestScorer(ctx, nil) + + podA := &types.PodMetrics{ + Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}}, + MetricsState: &backendmetrics.MetricsState{ + WaitingQueueSize: 2, + }, + } + + request := &types.LLMRequest{ + RequestId: "test-request-1", + } + + schedulingResult := &types.SchedulingResult{ + ProfileResults: map[string]*types.ProfileRunResult{ + "test-profile": { + TargetPods: []types.Pod{podA}, + }, + }, + } + + // First request + scorer.PreRequest(ctx, request, 
schedulingResult, 0) + + // Check cache and pod counts + compositeKey := "default/pod-a.test-request-1" + if !scorer.requestCache.Has(compositeKey) { + t.Errorf("Expected request to be in cache with key %s", compositeKey) + } + + scorer.mutex.RLock() + count := scorer.podCounts["default/pod-a"] + scorer.mutex.RUnlock() + if count != 1 { + t.Errorf("Expected pod-a count to be 1, got %d", count) + } + + // Second request with different ID to same pod + request2 := &types.LLMRequest{ + RequestId: "test-request-2", + } + schedulingResult2 := &types.SchedulingResult{ + ProfileResults: map[string]*types.ProfileRunResult{ + "test-profile": { + TargetPods: []types.Pod{podA}, + }, + }, + } + + scorer.PreRequest(ctx, request2, schedulingResult2, 0) + + // Check incremented count + scorer.mutex.RLock() + count = scorer.podCounts["default/pod-a"] + scorer.mutex.RUnlock() + if count != 2 { + t.Errorf("Expected pod-a count to be 2, got %d", count) + } + + // Check both requests are in cache + compositeKey2 := "default/pod-a.test-request-2" + if !scorer.requestCache.Has(compositeKey2) { + t.Errorf("Expected second request to be in cache with key %s", compositeKey2) + } +} + +func TestActiveRequestScorer_PostResponse(t *testing.T) { + ctx := context.Background() + + scorer := NewActiveRequestScorer(ctx, nil) + + request := &types.LLMRequest{ + RequestId: "test-request-1", + } + + podA := &types.PodMetrics{ + Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}}, + MetricsState: &backendmetrics.MetricsState{ + WaitingQueueSize: 2, + }, + } + // Setup initial state: add request through PreRequest + schedulingResult := &types.SchedulingResult{ + ProfileResults: map[string]*types.ProfileRunResult{ + "test-profile": { + TargetPods: []types.Pod{podA}, + }, + }, + } + + scorer.PreRequest(ctx, request, schedulingResult, 0) + + // Verify initial state + compositeKey := "default/pod-a.test-request-1" + if !scorer.requestCache.Has(compositeKey) { + 
t.Fatal("Request should be in cache before PostResponse") + } + + scorer.mutex.RLock() + initialCount := scorer.podCounts["default/pod-a"] + scorer.mutex.RUnlock() + if initialCount != 1 { + t.Fatalf("Expected initial count to be 1, got %d", initialCount) + } + + // Call PostResponse + scorer.PostResponse(ctx, request, &requestcontrol.Response{}, podA.GetPod()) + + // Check request is removed from cache + if scorer.requestCache.Has(compositeKey) { + t.Errorf("Request should be removed from cache after PostResponse") + } + + // Check pod count is decremented and removed (since it was 1) + scorer.mutex.RLock() + _, exists := scorer.podCounts["default/pod-a"] + scorer.mutex.RUnlock() + if exists { + t.Errorf("Pod should be removed from podCounts when count reaches 0") + } +} + +func TestActiveRequestScorer_TTLExpiration(t *testing.T) { + ctx := context.Background() + + // Use very short timeout for test + params := &ActiveRequestScorerParameters{RequestTimeout: "1s"} + scorer := NewActiveRequestScorer(ctx, params) // 1 second timeout + + request := &types.LLMRequest{ + RequestId: "test-request-ttl", + } + + podA := &types.PodMetrics{ + Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}}, + } + + schedulingResult := &types.SchedulingResult{ + ProfileResults: map[string]*types.ProfileRunResult{ + "test-profile": { + TargetPods: []types.Pod{podA}, + }, + }, + } + + // Add request + scorer.PreRequest(ctx, request, schedulingResult, 0) + + // Verify request is added + scorer.mutex.RLock() + initialCount := scorer.podCounts["default/pod-a"] + scorer.mutex.RUnlock() + if initialCount != 1 { + t.Fatalf("Expected initial count to be 1, got %d", initialCount) + } + + // Wait for TTL expiration + time.Sleep(2 * time.Second) + + // Trigger cleanup + scorer.requestCache.DeleteExpired() + + // Check that pod count is decremented due to TTL expiration + scorer.mutex.RLock() + _, exists := scorer.podCounts["default/pod-a"] + 
scorer.mutex.RUnlock() + if exists { + t.Errorf("Pod should be removed from podCounts after TTL expiration") + } +} + +func TestNewActiveRequestScorer_InvalidTimeout(t *testing.T) { + params := &ActiveRequestScorerParameters{RequestTimeout: "invalid"} + scorer := NewActiveRequestScorer(context.Background(), params) + + // Should use default timeout when invalid value is provided + if scorer == nil { + t.Error("Expected scorer to be created even with invalid timeout") + } +} + +func TestActiveRequestScorer_TypedName(t *testing.T) { + scorer := NewActiveRequestScorer(context.Background(), nil) + + typedName := scorer.TypedName() + if typedName.Type != ActiveRequestScorerType { + t.Errorf("Expected type %s, got %s", ActiveRequestScorerType, typedName.Type) + } +} + +func TestActiveRequestScorer_WithName(t *testing.T) { + scorer := NewActiveRequestScorer(context.Background(), nil) + testName := "test-scorer" + + scorer = scorer.WithName(testName) + + if scorer.TypedName().Name != testName { + t.Errorf("Expected name %s, got %s", testName, scorer.TypedName().Name) + } +}