Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/architecture.md
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,21 @@ Pods with requests in the queue will get score between 0.5 and 0.

---

#### ActiveRequestScorer

Scores pods based on the number of active requests being served per pod. Each request is tracked
individually with its own TTL to ensure accurate timeout handling. Pods with fewer active
requests receive higher scores.

Scores are normalized to a range of 0-1, where pods with fewer active requests get higher scores.

- **Type**: `active-request-scorer`
- **Parameters**:
- `requestTimeout`: specifies the timeout for requests as a Go duration string (e.g. `"30s"`, `"1m"`).
  Once a request is "in-flight" for this duration, it is considered timed out and automatically removed.

---

#### SessionAffinity

Scores the candidate pods by giving a higher score to the pods that were previously
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ require (
github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 // indirect
github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jellydator/ttlcache/v3 v3.4.0
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jellydator/ttlcache/v3 v3.4.0 h1:YS4P125qQS0tNhtL6aeYkheEaB/m8HCqdMMP4mnWdTY=
github.com/jellydator/ttlcache/v3 v3.4.0/go.mod h1:Hw9EgjymziQD3yGsQdf1FqFdpp7YjFMd4Srg5EJlgD4=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
Expand Down
1 change: 1 addition & 0 deletions pkg/plugins/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ func RegisterAllPlugins() {
plugins.Register(prefix.PrefixCachePluginType, scorer.PrefixCachePluginFactory)
plugins.Register(scorer.LoadAwareScorerType, scorer.LoadAwareScorerFactory)
plugins.Register(scorer.SessionAffinityScorerType, scorer.SessionAffinityScorerFactory)
plugins.Register(scorer.ActiveRequestScorerType, scorer.ActiveRequestScorerFactory)
}
246 changes: 246 additions & 0 deletions pkg/plugins/scorer/active_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
package scorer

import (
"context"
"encoding/json"
"fmt"
"sync"
"time"

"github.com/jellydator/ttlcache/v3"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

const (
	// ActiveRequestScorerType is the type of the ActiveRequestScorer,
	// used as the key when registering the plugin factory.
	ActiveRequestScorerType = "active-request-scorer"

	// defaultRequestTimeout defines the default timeout for open requests to be
	// considered stale and removed from the cache. It is used when no (or an
	// invalid) requestTimeout parameter is configured.
	defaultRequestTimeout = 2 * time.Minute
)

// ActiveRequestScorerParameters defines the parameters for the
// ActiveRequestScorer.
type ActiveRequestScorerParameters struct {
	// RequestTimeout defines the timeout for requests as a Go duration
	// string such as "30s", "1m", or "2h" (parsed with time.ParseDuration).
	// Once the request is "in-flight" for this duration, it is considered to
	// be timed out and dropped.
	// An empty, unparsable, or non-positive value falls back to
	// defaultRequestTimeout.
	RequestTimeout string `json:"requestTimeout"`
}

// requestEntry represents a single in-flight request tracked in the cache,
// identified by the pod serving it and the request's ID.
type requestEntry struct {
	PodName   string
	RequestID string
}

// String renders the composite cache key for this entry in the form
// "<podName>.<requestID>".
func (r *requestEntry) String() string {
	return fmt.Sprintf("%v.%v", r.PodName, r.RequestID)
}

// compile-time assertion that *ActiveRequestScorer satisfies framework.Scorer
var _ framework.Scorer = &ActiveRequestScorer{}

// ActiveRequestScorerFactory defines the factory function for the
// ActiveRequestScorer. It decodes the optional JSON parameters, builds the
// scorer, and assigns it the given plugin name.
func ActiveRequestScorerFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
	var params ActiveRequestScorerParameters
	if rawParameters != nil {
		if err := json.Unmarshal(rawParameters, &params); err != nil {
			return nil, fmt.Errorf("failed to parse the parameters of the '%s' scorer - %w", ActiveRequestScorerType, err)
		}
	}

	scorer := NewActiveRequestScorer(handle.Context(), &params)
	return scorer.WithName(name), nil
}

// NewActiveRequestScorer creates a new ActiveRequestScorer scorer.
func NewActiveRequestScorer(ctx context.Context, params *ActiveRequestScorerParameters) *ActiveRequestScorer {
requestTimeout := defaultRequestTimeout
logger := log.FromContext(ctx)

if params != nil && params.RequestTimeout != "" {
paramsRequestTimeout, err := time.ParseDuration(params.RequestTimeout)
if err != nil || paramsRequestTimeout <= 0 {
logger.Error(err, "Invalid request timeout duration, using default request timeout")
} else {
requestTimeout = paramsRequestTimeout
logger.Info("Using request timeout", "requestTimeout", requestTimeout)
}
}

// cache for individual requests with their own TTL
requestCache := ttlcache.New[string, *requestEntry](
ttlcache.WithTTL[string, *requestEntry](requestTimeout),
ttlcache.WithDisableTouchOnHit[string, *requestEntry](),
)

scorer := &ActiveRequestScorer{
typedName: plugins.TypedName{Type: ActiveRequestScorerType},
requestCache: requestCache,
podCounts: make(map[string]int),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there seems to be a preference in IGW for sync,Map over map+sync.Mutex

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Map type is specialized. Most code should use a plain Go map instead, with separate locking or coordination, for better type safety and to make it easier to maintain other invariants along with the map content.
The Map type is optimized for two common use cases: (1) when the entry for a given key is only ever written once but read many times, as in caches that only grow, or (2) when multiple goroutines read, write, and overwrite entries for disjoint sets of keys. In these two cases, use of a Map may significantly reduce lock contention compared to a Go map paired with a separate Mutex or RWMutex.

I think our use does not fit this categorization - what do you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I made the same argument on IGW and ultimately decided it wasn't worth it.
It matters if you intend to contribute to IGW later, but can delay the switch (if at all) to that point in time.

mutex: &sync.RWMutex{},
}
// callback to decrement count when requests expire
// most requests will be removed in PostResponse, but this ensures
// that we don't leak pod counts if PostResponse is not called
requestCache.OnEviction(func(_ context.Context, reason ttlcache.EvictionReason,
item *ttlcache.Item[string, *requestEntry]) {
if reason == ttlcache.EvictionReasonExpired {
scorer.decrementPodCount(item.Value().PodName)
}
})

go cleanCachePeriodically(ctx, requestCache, requestTimeout)

return scorer
}

// ActiveRequestScorer keeps track of individual requests being served
// per pod, scoring pods inversely to their number of active requests.
type ActiveRequestScorer struct {
	typedName plugins.TypedName

	// requestCache stores individual request entries with unique composite
	// keys (podName.requestID); each entry expires after the configured TTL
	requestCache *ttlcache.Cache[string, *requestEntry]

	// podCounts maintains fast lookup for request counts per pod.
	// Guarded by mutex; entries are deleted when their count reaches zero,
	// so stored counts are always >= 1.
	podCounts map[string]int
	mutex     *sync.RWMutex
}

// TypedName returns the typed name of the plugin. The Type is always
// ActiveRequestScorerType; the Name is whatever was set via WithName.
func (s *ActiveRequestScorer) TypedName() plugins.TypedName {
	return s.typedName
}

// WithName sets the name of the plugin and returns the receiver to allow
// call chaining (as done in the factory).
func (s *ActiveRequestScorer) WithName(name string) *ActiveRequestScorer {
	s.typedName.Name = name
	return s
}

// Score scores the given pods based on the number of active requests
// being served by each pod. The score is normalized to a range of 0-1,
// where pods with no tracked requests receive 1.0 and the busiest pod
// receives 0.
func (s *ActiveRequestScorer) Score(ctx context.Context, _ *types.CycleState, _ *types.LLMRequest,
	pods []types.Pod) map[types.Pod]float64 {
	// Snapshot the per-pod counts under the read lock so the scoring loop
	// below runs without holding it.
	counts := make(map[string]int)
	highest := 0
	s.mutex.RLock()
	for name, active := range s.podCounts {
		counts[name] = active
		if active >= highest {
			highest = active
		}
	}
	s.mutex.RUnlock()

	scores := make(map[types.Pod]float64, len(pods))
	for _, candidate := range pods {
		name := candidate.GetPod().NamespacedName.String()
		active, tracked := counts[name]
		switch {
		case !tracked || active == 0:
			scores[candidate] = 1.0 // no requests means highest score
		default:
			scores[candidate] = float64(highest-active) / float64(highest)
		}
	}

	log.FromContext(ctx).V(logutil.DEBUG).Info("Scored pods", "scores", scores)
	return scores
}

// PreRequest is called before a request is sent to the target pod.
// It creates a new request entry in the cache with its own TTL and
// increments the pod count for fast lookup.
func (s *ActiveRequestScorer) PreRequest(ctx context.Context, request *types.LLMRequest,
	schedulingResult *types.SchedulingResult, _ int) {
	debugLogger := log.FromContext(ctx).V(logutil.DEBUG)

	for _, profileResult := range schedulingResult.ProfileResults { // schedulingResult guaranteed not to be nil
		// len() is safe on a nil slice, so a separate nil check on
		// TargetPods is redundant.
		if profileResult == nil || len(profileResult.TargetPods) == 0 {
			continue
		}

		// create request entry for first pod only. TODO: support fallback pods
		entry := &requestEntry{
			PodName:   profileResult.TargetPods[0].GetPod().NamespacedName.String(),
			RequestID: request.RequestId,
		}

		// add to request cache with TTL
		s.requestCache.Set(entry.String(), entry, 0) // Use default TTL
		s.incrementPodCount(entry.PodName)

		debugLogger.Info("Added request to cache", "requestEntry", entry.String())
	}
}

// PostResponse is called after a response is sent to the client.
// It removes the specific request entry from the cache and decrements
// the pod count.
func (s *ActiveRequestScorer) PostResponse(ctx context.Context, request *types.LLMRequest,
_ *requestcontrol.Response, targetPod *backend.Pod) {
debugLogger := log.FromContext(ctx).V(logutil.DEBUG).WithName("ActiveRequestScorer.PostResponse")
if targetPod == nil {
debugLogger.Info("Skipping PostResponse because targetPod is nil")
return
}

entry := requestEntry{targetPod.NamespacedName.String(), request.RequestId}

if _, found := s.requestCache.GetAndDelete(entry.String()); found {
s.decrementPodCount(entry.PodName)
Copy link

@VadimEisenberg VadimEisenberg Sep 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vMaroon From reading the code, it seems PreRequest adds first target pods of each of ProfileResults, for example of both decode and prefill, while PostResponse removes only the primary target pod (of the decode).

The target pod in PostResponse is set to be the target pod of RequestCtx, which is the first target pod from the primary ProfileResult:

https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/8b154baffd35e1c4ad1dcd131f1cbcac04ddc304/pkg/epp/requestcontrol/director.go#L253C2-L263C34

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh - apologies, I misunderstood. Then you're right, it does not track that granularity in P/D.

We can definitely bump this in priority if you provide info. For example, given 2-points PostResponse hooks (cc @kfswain), the start can set prefill as done and the end updates decode.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vMaroon @kfswain maybe PostResponse should get target pods of all the profiles, and not only of the primary profile? Or canPostResponse receive SchedulingResult as a parameter, symmetrically to PreRequest?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@VadimEisenberg PostResponse intentionally receives the target pod who actually SERVED the request.
There is another factor that wasn't taken into consideration in this discussion -
EPP protocol allows to define multiple (prioritized) endpoints as candidates for serving.
more details here:
https://github.com/kubernetes-sigs/gateway-api-inference-extension/tree/main/docs/proposals/004-endpoint-picker-protocol#destination-endpoint

this essentially tells Envoy that if the first endpoint in the list failed in serving the request, envoy should try to next endpoint in the list.
for all scorers (or other plugins) that maintain a state per request, it is useful to know which endpoint was serving the request and not which one was ranked first (and didn't necessarily served successfully).
This information should be reported by Gateways back to EPP:
https://github.com/kubernetes-sigs/gateway-api-inference-extension/tree/main/docs/proposals/004-endpoint-picker-protocol#destination-endpoint-served

The above is still under development and therefore we currently use picker with configuration of MaxEndpoints=1.

let me know if it makes sense.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nirrozenbaum I see. My problem currently is that active-request-scorer will not work correctly with regard to prefill pods. Prefill pods will be incremented in the PreRequest, but will not be decremented in PostResponse. They will be removed by TTL eventually.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@VadimEisenberg right.
I think that until PostResponse issue is fixed, this scorer should either count only decode results, or it cannot be used as is.
cc: @vMaroon

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For PD I think we can do with the start/end PostResponse endpoints and the scorer can use request-id to maintain happy-path association of 1 P and 1 D per request.

debugLogger.Info("Removed request from cache", "requestEntry", entry.String())
} else {
debugLogger.Info("Request not found in cache", "requestEntry", entry.String())
}
}

// incrementPodCount increments the active-request count for the named pod,
// creating the entry if it does not exist yet (map access of a missing key
// yields the zero value).
func (s *ActiveRequestScorer) incrementPodCount(podName string) {
	s.mutex.Lock()
	s.podCounts[podName] = s.podCounts[podName] + 1
	s.mutex.Unlock()
}

// decrementPodCount decrements the active-request count for the named pod
// and removes the map entry entirely once the count reaches zero, so the
// map never stores zero (or negative) counts.
func (s *ActiveRequestScorer) decrementPodCount(podName string) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	count, exists := s.podCounts[podName]
	if !exists {
		// nothing tracked for this pod; likely already expired from the cache
		return
	}
	if count > 1 {
		s.podCounts[podName] = count - 1
		return
	}
	// dropping the last active request for this pod
	delete(s.podCounts, podName)
}

// cleanCachePeriodically evicts expired entries from the cache once every
// requestTimeout interval, stopping when ctx is cancelled. Eviction in turn
// fires the cache's OnEviction callback, which keeps podCounts in sync.
func cleanCachePeriodically(ctx context.Context, cache *ttlcache.Cache[string, *requestEntry], requestTimeout time.Duration) {
	sweep := time.NewTicker(requestTimeout)
	defer sweep.Stop()

	for {
		select {
		case <-sweep.C:
			cache.DeleteExpired()
		case <-ctx.Done():
			return
		}
	}
}
Loading
Loading