+
\ No newline at end of file
diff --git a/pkg/epp/backend/metrics/fake.go b/pkg/epp/backend/metrics/fake.go
index 83ce9a7fc..613ebf5ec 100644
--- a/pkg/epp/backend/metrics/fake.go
+++ b/pkg/epp/backend/metrics/fake.go
@@ -22,7 +22,6 @@ import (
"sync"
"time"
- corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/log"
@@ -33,11 +32,8 @@ import (
// FakePodMetrics is an implementation of PodMetrics that doesn't run the async refresh loop.
type FakePodMetrics struct {
- Pod *backend.Pod
- Metrics *MetricsState
- runningRequests *datalayer.RequestPriorityQueue
- stopped bool
- mu sync.RWMutex // Protect the stopped field and operations
+ Pod *backend.Pod
+ Metrics *MetricsState
}
func (fpm *FakePodMetrics) String() string {
@@ -52,100 +48,8 @@ func (fpm *FakePodMetrics) GetMetrics() *MetricsState {
return fpm.Metrics
}
-func (fpm *FakePodMetrics) UpdatePod(pod *corev1.Pod) {
- fpm.Pod = toInternalPod(pod, nil)
-}
-
-func (f *FakePodMetrics) StopRefreshLoop() {
- f.mu.Lock()
- defer f.mu.Unlock()
- f.stopped = true
-}
-
-func (f *FakePodMetrics) GetRunningRequests() *datalayer.RequestPriorityQueue {
- f.mu.RLock()
- defer f.mu.RUnlock()
- if f.stopped {
- return nil // Return nil for stopped pod metrics
- }
- return f.runningRequests
-}
-
-func (f *FakePodMetrics) AddRequest(requestID string, tpot float64) bool {
- f.mu.RLock()
- defer f.mu.RUnlock()
- if f.stopped {
- return false // Reject operations after stopped
- }
- return f.runningRequests.Add(requestID, tpot)
-}
-
-func (f *FakePodMetrics) RemoveRequest(requestID string) bool {
- f.mu.RLock()
- defer f.mu.RUnlock()
- if f.stopped {
- return false // Reject operations after stopped
- }
- _, success := f.runningRequests.Remove(requestID)
- return success
-}
-
-func (f *FakePodMetrics) UpdateRequest(requestID string, tpot float64) bool {
- f.mu.RLock()
- defer f.mu.RUnlock()
- if f.stopped {
- return false // Reject operations after stopped
- }
- return f.runningRequests.Update(requestID, tpot)
-}
-
-func (f *FakePodMetrics) GetRequestCount() int {
- f.mu.RLock()
- defer f.mu.RUnlock()
- if f.stopped {
- return 0 // Return 0 after stopped
- }
- return f.runningRequests.GetSize()
-}
-
-func (f *FakePodMetrics) ContainsRequest(requestID string) bool {
- pod := f.GetPod()
- if pod == nil || pod.RunningRequests == nil {
- return false
- }
- return pod.RunningRequests.Contains(requestID)
-}
-
-func (srv *FakePodMetrics) PeekRequestPriorityQueue() *datalayer.Request {
- pod := srv.GetPod()
- if pod == nil || pod.RunningRequests == nil {
- return nil
- }
- return pod.RunningRequests.Peek()
-}
-
-func NewFakePodMetrics(k8sPod *corev1.Pod) *FakePodMetrics {
- labels := make(map[string]string)
- for k, v := range k8sPod.Labels {
- labels[k] = v
- }
-
- pod := &backend.Pod{
- NamespacedName: types.NamespacedName{
- Name: k8sPod.Name,
- Namespace: k8sPod.Namespace,
- },
- Address: k8sPod.Status.PodIP,
- Labels: labels,
- RunningRequests: datalayer.NewRequestPriorityQueue(),
- }
-
- return &FakePodMetrics{
- Pod: pod,
- Metrics: &MetricsState{UpdateTime: time.Now()},
- runningRequests: datalayer.NewRequestPriorityQueue(),
- stopped: false,
- }
+func (fpm *FakePodMetrics) UpdatePod(pod *datalayer.PodInfo) {
+ fpm.Pod = pod
}
func (*FakePodMetrics) Put(string, datalayer.Cloneable) {}
@@ -164,7 +68,7 @@ type FakePodMetricsClient struct {
Res map[types.NamespacedName]*MetricsState
}
-func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState, _ int32) (*MetricsState, error) {
+func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error) {
f.errMu.RLock()
err, ok := f.Err[pod.NamespacedName]
f.errMu.RUnlock()
diff --git a/pkg/epp/backend/metrics/metrics.go b/pkg/epp/backend/metrics/metrics.go
index 8927b1b12..6f6be6820 100644
--- a/pkg/epp/backend/metrics/metrics.go
+++ b/pkg/epp/backend/metrics/metrics.go
@@ -25,6 +25,7 @@ import (
dto "github.com/prometheus/client_model/go"
"github.com/prometheus/common/expfmt"
+ "github.com/prometheus/common/model"
"go.uber.org/multierr"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
@@ -35,11 +36,12 @@ const (
LoraInfoRunningAdaptersMetricName = "running_lora_adapters"
LoraInfoWaitingAdaptersMetricName = "waiting_lora_adapters"
LoraInfoMaxAdaptersMetricName = "max_lora"
+
+ CacheConfigBlockSizeInfoMetricName = "block_size"
)
type PodMetricsClientImpl struct {
MetricMapping *MetricMapping
- ModelServerMetricsPort int32
ModelServerMetricsPath string
ModelServerMetricsScheme string
@@ -47,8 +49,8 @@ type PodMetricsClientImpl struct {
}
// FetchMetrics fetches metrics from a given pod, clones the existing metrics object and returns an updated one.
-func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState, port int32) (*MetricsState, error) {
- url := p.getMetricEndpoint(pod, port)
+func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error) {
+ url := p.getMetricEndpoint(pod)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %v", err)
@@ -65,7 +67,7 @@ func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Po
return nil, fmt.Errorf("unexpected status code from %s: %v", pod.NamespacedName, resp.StatusCode)
}
- parser := expfmt.TextParser{}
+ parser := expfmt.NewTextParser(model.LegacyValidation)
metricFamilies, err := parser.TextToMetricFamilies(resp.Body)
if err != nil {
return nil, err
@@ -73,11 +75,8 @@ func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Po
return p.promToPodMetrics(metricFamilies, existing)
}
-func (p *PodMetricsClientImpl) getMetricEndpoint(pod *backend.Pod, targetPortNumber int32) string {
- if p.ModelServerMetricsPort == 0 {
- p.ModelServerMetricsPort = targetPortNumber
- }
- return fmt.Sprintf("%s://%s:%d%s", p.ModelServerMetricsScheme, pod.Address, p.ModelServerMetricsPort, p.ModelServerMetricsPath)
+func (p *PodMetricsClientImpl) getMetricEndpoint(pod *backend.Pod) string {
+ return p.ModelServerMetricsScheme + "://" + pod.GetMetricsHost() + p.ModelServerMetricsPath
}
// promToPodMetrics updates internal pod metrics with scraped Prometheus metrics.
@@ -152,6 +151,24 @@ func (p *PodMetricsClientImpl) promToPodMetrics(
}
}
+ if p.MetricMapping.CacheConfigInfo != nil {
+ cacheMetrics, err := p.getMetric(metricFamilies, *p.MetricMapping.CacheConfigInfo)
+ if err != nil {
+ errs = multierr.Append(errs, err)
+ } else {
+ for _, v := range cacheMetrics.GetLabel() {
+ if v.GetName() == CacheConfigBlockSizeInfoMetricName {
+ updated.CacheBlockSize, err = strconv.Atoi(v.GetValue())
+ if err != nil {
+ errs = multierr.Append(errs, err)
+ } else {
+ break
+ }
+ }
+ }
+ }
+ }
+
return updated, errs
}
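Illustrative sketch (not part of the patch): how the new CacheConfigInfo handling reads the block size from an info-style metric's labels. The family name vllm:cache_config_info and the extra num_gpu_blocks label are assumptions for the example; the real family is whatever the configured MetricMapping.CacheConfigInfo spec names.

package main

import (
	"fmt"
	"strconv"
	"strings"

	"github.com/prometheus/common/expfmt"
	"github.com/prometheus/common/model"
)

func main() {
	exposition := "# TYPE vllm:cache_config_info gauge\n" +
		"vllm:cache_config_info{block_size=\"16\",num_gpu_blocks=\"8192\"} 1\n"
	parser := expfmt.NewTextParser(model.LegacyValidation)
	families, err := parser.TextToMetricFamilies(strings.NewReader(exposition))
	if err != nil {
		panic(err)
	}
	// Same label scan as the new promToPodMetrics block: find "block_size" and Atoi it.
	for _, label := range families["vllm:cache_config_info"].GetMetric()[0].GetLabel() {
		if label.GetName() == "block_size" {
			blockSize, _ := strconv.Atoi(label.GetValue())
			fmt.Println("CacheBlockSize:", blockSize) // 16
		}
	}
}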
diff --git a/pkg/epp/backend/metrics/metrics_spec.go b/pkg/epp/backend/metrics/metrics_spec.go
index 782f7427e..b3c26db2c 100644
--- a/pkg/epp/backend/metrics/metrics_spec.go
+++ b/pkg/epp/backend/metrics/metrics_spec.go
@@ -33,6 +33,7 @@ type MetricMapping struct {
TotalRunningRequests *MetricSpec
KVCacheUtilization *MetricSpec
LoraRequestInfo *MetricSpec
+ CacheConfigInfo *MetricSpec
}
// stringToMetricSpec converts a string to a MetricSpec.
@@ -94,7 +95,7 @@ func stringToMetricSpec(specStr string) (*MetricSpec, error) {
}
// NewMetricMapping creates a MetricMapping from string values.
-func NewMetricMapping(queuedStr, runningStr, kvUsageStr, loraReqInfoStr string) (*MetricMapping, error) {
+func NewMetricMapping(queuedStr, runningStr, kvUsageStr, loraReqInfoStr, cacheInfoMetric string) (*MetricMapping, error) {
queuedSpec, err := stringToMetricSpec(queuedStr)
if err != nil {
return nil, fmt.Errorf("error parsing WaitingRequests: %w", err)
@@ -111,11 +112,18 @@ func NewMetricMapping(queuedStr, runningStr, kvUsageStr, loraReqInfoStr string)
if err != nil {
return nil, fmt.Errorf("error parsing loraReqInfoStr: %w", err)
}
+
+ cacheInfoSpec, err := stringToMetricSpec(cacheInfoMetric)
+ if err != nil {
+ return nil, fmt.Errorf("error parsing cacheInfoMetric: %w", err)
+ }
+
mapping := &MetricMapping{
TotalQueuedRequests: queuedSpec,
TotalRunningRequests: runningSpec,
KVCacheUtilization: kvUsageSpec,
LoraRequestInfo: loraReqInfoSpec,
+ CacheConfigInfo: cacheInfoSpec,
}
return mapping, nil
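Illustrative usage of the widened NewMetricMapping signature (not part of the patch); the vLLM metric names below are assumptions used for the example, the real values come from the EPP command-line flags.

// newDefaultMapping is a hypothetical helper in the same package showing the new
// fifth argument; an unparsable spec string surfaces as a wrapped error naming the
// offending parameter.
func newDefaultMapping() (*MetricMapping, error) {
	return NewMetricMapping(
		"vllm:num_requests_waiting", // queuedStr
		"vllm:num_requests_running", // runningStr
		"vllm:gpu_cache_usage_perc", // kvUsageStr
		"vllm:lora_requests_info",   // loraReqInfoStr
		"vllm:cache_config_info",    // cacheInfoMetric (new)
	)
}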
diff --git a/pkg/epp/backend/metrics/metrics_test.go b/pkg/epp/backend/metrics/metrics_test.go
index 2dd8ca5dd..502ad6f09 100644
--- a/pkg/epp/backend/metrics/metrics_test.go
+++ b/pkg/epp/backend/metrics/metrics_test.go
@@ -489,7 +489,9 @@ func TestPromToPodMetrics(t *testing.T) {
func TestFetchMetrics(t *testing.T) {
ctx := logutil.NewTestLoggerIntoContext(context.Background())
pod := &backend.Pod{
- Address: "127.0.0.1",
+ Address: "127.0.0.1",
+ Port: "9999",
+ MetricsHost: "127.0.0.1:9999",
NamespacedName: types.NamespacedName{
Namespace: "test",
Name: "pod",
@@ -499,12 +501,11 @@ func TestFetchMetrics(t *testing.T) {
// No MetricMapping needed for this basic test
p := &PodMetricsClientImpl{
ModelServerMetricsScheme: "http",
- ModelServerMetricsPort: 9999,
ModelServerMetricsPath: "/metrics",
Client: http.DefaultClient,
}
- _, err := p.FetchMetrics(ctx, pod, existing, 9999) // Use a port that's unlikely to be in use
+ _, err := p.FetchMetrics(ctx, pod, existing) // MetricsHost points at port 9999, which is unlikely to be in use
if err == nil {
t.Errorf("FetchMetrics() expected error, got nil")
}
diff --git a/pkg/epp/backend/metrics/pod_metrics.go b/pkg/epp/backend/metrics/pod_metrics.go
index 9ee142610..a1114aecf 100644
--- a/pkg/epp/backend/metrics/pod_metrics.go
+++ b/pkg/epp/backend/metrics/pod_metrics.go
@@ -24,8 +24,6 @@ import (
"time"
"github.com/go-logr/logr"
- corev1 "k8s.io/api/core/v1"
- "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
@@ -51,7 +49,7 @@ type podMetrics struct {
}
type PodMetricsClient interface {
- FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState, port int32) (*MetricsState, error)
+ FetchMetrics(ctx context.Context, pod *backend.Pod, existing *MetricsState) (*MetricsState, error)
}
func (pm *podMetrics) String() string {
@@ -66,98 +64,8 @@ func (pm *podMetrics) GetMetrics() *MetricsState {
return pm.metrics.Load()
}
-// New methods for priority queue integration
-func (pm *podMetrics) GetRunningRequests() *datalayer.RequestPriorityQueue {
- pod := pm.GetPod()
- if pod == nil {
- return nil
- }
- return pod.RunningRequests
-}
-
-func (pm *podMetrics) AddRequest(requestID string, tpot float64) bool {
- pod := pm.GetPod()
- if pod == nil || pod.RunningRequests == nil {
- return false
- }
- success := pod.RunningRequests.Add(requestID, tpot)
- // No need to update metrics since we removed ActualRunningRequests
- return success
-}
-
-func (pm *podMetrics) RemoveRequest(requestID string) bool {
- pod := pm.GetPod()
- if pod == nil || pod.RunningRequests == nil {
- return false
- }
- _, success := pod.RunningRequests.Remove(requestID)
- // No need to update metrics since we removed ActualRunningRequests
- return success
-}
-
-func (pm *podMetrics) UpdateRequest(requestID string, tpot float64) bool {
- pod := pm.GetPod()
- if pod == nil || pod.RunningRequests == nil {
- return false
- }
- return pod.RunningRequests.Update(requestID, tpot)
-}
-
-func (pm *podMetrics) GetRequestCount() int {
- pod := pm.GetPod()
- if pod == nil || pod.RunningRequests == nil {
- return 0
- }
- return pod.RunningRequests.GetSize()
-}
-
-func (pm *podMetrics) ContainsRequest(requestID string) bool {
- pod := pm.GetPod()
- if pod == nil || pod.RunningRequests == nil {
- return false
- }
- return pod.RunningRequests.Contains(requestID)
-}
-
-func (pm *podMetrics) PeekRequestPriorityQueue() *datalayer.Request {
- pod := pm.GetPod()
- if pod == nil || pod.RunningRequests == nil {
- return nil
- }
- return pod.RunningRequests.Peek()
-}
-
-func (pm *podMetrics) UpdatePod(k8sPod *corev1.Pod) {
- currentPod := pm.GetPod()
- updatedPod := toInternalPod(k8sPod, currentPod.GetRunningRequests())
-
- // Preserve the existing running requests queue if it exists
- if currentPod != nil && currentPod.GetRunningRequests() != nil {
- updatedPod.RunningRequests = currentPod.GetRunningRequests()
- }
-
- pm.pod.Store(updatedPod)
-}
-func toInternalPod(pod *corev1.Pod, existingQueue *datalayer.RequestPriorityQueue) *backend.Pod {
- labels := make(map[string]string, len(pod.GetLabels()))
- for key, value := range pod.GetLabels() {
- labels[key] = value
- }
-
- queue := existingQueue
- if queue == nil {
- queue = datalayer.NewRequestPriorityQueue()
- }
-
- return &backend.Pod{
- NamespacedName: types.NamespacedName{
- Name: pod.Name,
- Namespace: pod.Namespace,
- },
- Address: pod.Status.PodIP,
- Labels: labels,
- RunningRequests: queue,
- }
+func (pm *podMetrics) UpdatePod(pod *datalayer.PodInfo) {
+ pm.pod.Store(pod)
}
// start starts a goroutine exactly once to periodically update metrics. The goroutine will be
@@ -185,17 +93,9 @@ func (pm *podMetrics) startRefreshLoop(ctx context.Context) {
}
func (pm *podMetrics) refreshMetrics() error {
- pool, err := pm.ds.PoolGet()
- if err != nil {
- // No inference pool or not initialize.
- return err
- }
ctx, cancel := context.WithTimeout(context.Background(), fetchMetricsTimeout)
defer cancel()
- if len(pool.Spec.TargetPorts) != 1 {
- return fmt.Errorf("expected 1 target port, got %d", len(pool.Spec.TargetPorts))
- }
- updated, err := pm.pmc.FetchMetrics(ctx, pm.GetPod(), pm.GetMetrics(), int32(pool.Spec.TargetPorts[0].Number))
+ updated, err := pm.pmc.FetchMetrics(ctx, pm.GetPod(), pm.GetMetrics())
if err != nil {
pm.logger.V(logutil.TRACE).Info("Failed to refreshed metrics:", "err", err)
}
diff --git a/pkg/epp/backend/metrics/pod_metrics_test.go b/pkg/epp/backend/metrics/pod_metrics_test.go
index 49a1b3d2d..b0297cd1e 100644
--- a/pkg/epp/backend/metrics/pod_metrics_test.go
+++ b/pkg/epp/backend/metrics/pod_metrics_test.go
@@ -17,31 +17,25 @@ package metrics
import (
"context"
- "fmt"
- "sync"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/assert"
- corev1 "k8s.io/api/core/v1"
- metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
)
var (
- pod1 = &corev1.Pod{
- ObjectMeta: metav1.ObjectMeta{
- Name: "pod1",
+ pod1Info = &datalayer.PodInfo{
+ NamespacedName: types.NamespacedName{
+ Name: "pod1-rank-0",
Namespace: "default",
- Labels: map[string]string{"app": "test"},
- },
- Status: corev1.PodStatus{
- PodIP: "192.168.1.1",
},
+ PodName: "pod1",
}
initial = &MetricsState{
WaitingQueueSize: 0,
@@ -71,12 +65,11 @@ func TestMetricsRefresh(t *testing.T) {
pmf := NewPodMetricsFactory(pmc, time.Millisecond)
// The refresher is initialized with empty metrics.
- pm := pmf.NewEndpoint(ctx, pod1, &fakeDataStore{})
+ pm := pmf.NewEndpoint(ctx, pod1Info, &fakeDataStore{})
- namespacedName := types.NamespacedName{Name: pod1.Name, Namespace: pod1.Namespace}
// Use SetRes to simulate an update of metrics from the pod.
// Verify that the metrics are updated.
- pmc.SetRes(map[types.NamespacedName]*MetricsState{namespacedName: initial})
+ pmc.SetRes(map[types.NamespacedName]*MetricsState{pod1Info.NamespacedName: initial})
condition := func(collect *assert.CollectT) {
assert.True(collect, cmp.Equal(pm.GetMetrics(), initial, cmpopts.IgnoreFields(MetricsState{}, "UpdateTime")))
}
@@ -86,182 +79,11 @@ func TestMetricsRefresh(t *testing.T) {
// new update.
pmf.ReleaseEndpoint(pm)
time.Sleep(pmf.refreshMetricsInterval * 2 /* small buffer for robustness */)
- pmc.SetRes(map[types.NamespacedName]*MetricsState{namespacedName: updated})
+ pmc.SetRes(map[types.NamespacedName]*MetricsState{pod1Info.NamespacedName: updated})
// Still expect the same condition (no metrics update).
assert.EventuallyWithT(t, condition, time.Second, time.Millisecond)
}
-// Test priority queue functionality
-func TestPodMetricsRequestManagement(t *testing.T) {
- ctx := context.Background()
- pmc := &FakePodMetricsClient{}
- pmf := NewPodMetricsFactory(pmc, time.Minute) // Long interval to avoid interference
-
- pme := pmf.NewEndpoint(ctx, pod1, &fakeDataStore{})
- pm := pme.(*podMetrics) // Type assertion to access podMetrics methods
-
- defer pmf.ReleaseEndpoint(pm)
-
- // Test adding requests
- assert.True(t, pm.AddRequest("req1", 1.5))
- assert.True(t, pm.AddRequest("req2", 2.0))
- assert.False(t, pm.AddRequest("req1", 1.0)) // Duplicate should fail
-
- // Test request count
- assert.Equal(t, 2, pm.GetRequestCount())
-
- // Test contains request
- assert.True(t, pm.ContainsRequest("req1"))
- assert.False(t, pm.ContainsRequest("req3"))
-
- // Test update request
- assert.True(t, pm.UpdateRequest("req1", 0.5))
- assert.False(t, pm.UpdateRequest("req3", 1.0)) // Non-existent
-
- // Test remove request
- assert.True(t, pm.RemoveRequest("req1"))
- assert.False(t, pm.RemoveRequest("req1")) // Already removed
- assert.Equal(t, 1, pm.GetRequestCount())
-
- // Test getting running requests queue
- queue := pm.GetRunningRequests()
- assert.NotNil(t, queue)
- assert.Equal(t, 1, queue.GetSize())
-}
-
-// Test pod updates preserve request queue
-func TestPodUpdatePreservesQueue(t *testing.T) {
- ctx := context.Background()
- pmc := &FakePodMetricsClient{}
- pmf := NewPodMetricsFactory(pmc, time.Minute)
-
- pme := pmf.NewEndpoint(ctx, pod1, &fakeDataStore{})
- pm := pme.(*podMetrics) // Type assertion to access podMetrics methods
-
- defer pmf.ReleaseEndpoint(pm)
-
- // Add some requests
- assert.True(t, pm.AddRequest("req1", 1.5))
- assert.True(t, pm.AddRequest("req2", 2.0))
- assert.Equal(t, 2, pm.GetRequestCount())
-
- // Update pod with new IP
- updatedPod := pod1.DeepCopy()
- updatedPod.Status.PodIP = "192.168.1.2"
- updatedPod.Labels["new"] = "label"
-
- pm.UpdatePod(updatedPod)
-
- // Queue should be preserved
- assert.Equal(t, 2, pm.GetRequestCount())
- assert.True(t, pm.ContainsRequest("req1"))
- assert.True(t, pm.ContainsRequest("req2"))
-
- // Pod properties should be updated
- pod := pm.GetPod()
- assert.Equal(t, "192.168.1.2", pod.Address)
- assert.Equal(t, "label", pod.Labels["new"])
-}
-
-// Test error handling in metrics refresh
-func TestMetricsRefreshWithErrors(t *testing.T) {
- ctx := context.Background()
- pmc := &FakePodMetricsClient{}
- pmf := NewPodMetricsFactory(pmc, time.Millisecond)
-
- pme := pmf.NewEndpoint(ctx, pod1, &fakeDataStore{})
- pm := pme.(*podMetrics) // Type assertion to access podMetrics methods
-
- defer pmf.ReleaseEndpoint(pm)
-
- namespacedName := types.NamespacedName{Name: pod1.Name, Namespace: pod1.Namespace}
-
- // Set an error for this pod
- pmc.SetErr(map[types.NamespacedName]error{
- namespacedName: fmt.Errorf("connection failed"),
- })
-
- // Metrics should still be accessible (error is logged but not fatal)
- // The pod metrics should continue to work
- assert.NotNil(t, pm.GetMetrics())
- assert.NotNil(t, pm.GetPod())
-
- // Request operations should still work
- assert.True(t, pm.AddRequest("req1", 1.5))
- assert.Equal(t, 1, pm.GetRequestCount())
-}
-
-// Test string representation
-func TestPodMetricsString(t *testing.T) {
- ctx := context.Background()
- pmc := &FakePodMetricsClient{}
- pmf := NewPodMetricsFactory(pmc, time.Minute)
-
- pme := pmf.NewEndpoint(ctx, pod1, &fakeDataStore{})
- pm := pme.(*podMetrics) // Type assertion to access podMetrics methods
-
- defer pmf.ReleaseEndpoint(pm)
-
- // Add some requests
- pm.AddRequest("req1", 1.5)
- pm.AddRequest("req2", 2.0)
-
- str := pm.String()
- assert.Contains(t, str, "pod1")
- assert.Contains(t, str, "default")
- assert.Contains(t, str, "[req1(1.50), req2(2.00)]")
- assert.Contains(t, str, "192.168.1.1")
-}
-
-// Test concurrent access to request operations
-func TestConcurrentRequestOperations(t *testing.T) {
- ctx := context.Background()
- pmc := &FakePodMetricsClient{}
- pmf := NewPodMetricsFactory(pmc, time.Minute)
-
- pme := pmf.NewEndpoint(ctx, pod1, &fakeDataStore{})
- pm := pme.(*podMetrics) // Type assertion to access podMetrics methods
-
- defer pmf.ReleaseEndpoint(pm)
-
- const numGoroutines = 10
- const requestsPerGoroutine = 100
-
- var wg sync.WaitGroup
-
- // Launch goroutines that add requests
- for i := 0; i < numGoroutines; i++ {
- wg.Add(1)
- go func(id int) {
- defer wg.Done()
- for j := 0; j < requestsPerGoroutine; j++ {
- requestID := fmt.Sprintf("req-%d-%d", id, j)
- pm.AddRequest(requestID, float64(j))
- }
- }(i)
- }
-
- // Launch goroutines that check and remove requests
- for i := 0; i < numGoroutines/2; i++ {
- wg.Add(1)
- go func(id int) {
- defer wg.Done()
- for j := 0; j < requestsPerGoroutine/2; j++ {
- requestID := fmt.Sprintf("req-%d-%d", id, j)
- if pm.ContainsRequest(requestID) {
- pm.RemoveRequest(requestID)
- }
- }
- }(i)
- }
-
- wg.Wait()
-
- // Should not crash and should have some requests remaining
- count := pm.GetRequestCount()
- assert.True(t, count >= 0) // Basic sanity check
-}
-
type fakeDataStore struct{}
func (f *fakeDataStore) PoolGet() (*v1.InferencePool, error) {
diff --git a/pkg/epp/backend/metrics/types.go b/pkg/epp/backend/metrics/types.go
index cbb4dc7df..99f15a20f 100644
--- a/pkg/epp/backend/metrics/types.go
+++ b/pkg/epp/backend/metrics/types.go
@@ -22,7 +22,6 @@ import (
"sync"
"time"
- corev1 "k8s.io/api/core/v1"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
@@ -53,8 +52,7 @@ type PodMetricsFactory struct {
refreshMetricsInterval time.Duration
}
-func (f *PodMetricsFactory) NewEndpoint(parentCtx context.Context, in *corev1.Pod, ds datalayer.PoolInfo) PodMetrics {
- pod := toInternalPod(in, nil) // Pass nil for new pod - will create new queue
+func (f *PodMetricsFactory) NewEndpoint(parentCtx context.Context, pod *datalayer.PodInfo, ds datalayer.PoolInfo) PodMetrics {
pm := &podMetrics{
pmc: f.pmc,
ds: ds,
diff --git a/pkg/epp/config/loader/configloader.go b/pkg/epp/config/loader/configloader.go
index 8e80b037d..865eae28b 100644
--- a/pkg/epp/config/loader/configloader.go
+++ b/pkg/epp/config/loader/configloader.go
@@ -31,6 +31,7 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/profile"
)
var scheme = runtime.NewScheme()
@@ -113,6 +114,10 @@ func loadSchedulerConfig(configProfiles []configapi.SchedulingProfile, handle pl
return nil, errors.New("no profile handler was specified")
}
+ if profileHandler.TypedName().Type == profile.SingleProfileHandlerType && len(profiles) > 1 {
+ return nil, errors.New("single profile handler is intended to be used with a single profile, but multiple profiles were specified")
+ }
+
return scheduling.NewSchedulerConfig(profileHandler, profiles), nil
}
diff --git a/pkg/epp/config/loader/configloader_test.go b/pkg/epp/config/loader/configloader_test.go
index 5bf5a6608..c00563ad3 100644
--- a/pkg/epp/config/loader/configloader_test.go
+++ b/pkg/epp/config/loader/configloader_test.go
@@ -73,7 +73,7 @@ func TestLoadRawConfiguration(t *testing.T) {
},
{
Type: test2Type,
- Parameters: json.RawMessage("{\"hashBlockSize\":32}"),
+ Parameters: json.RawMessage("{\"blockSize\":32}"),
},
{
Name: "testPicker",
@@ -175,7 +175,7 @@ func TestLoadRawConfigurationWithDefaults(t *testing.T) {
{
Name: test2Type,
Type: test2Type,
- Parameters: json.RawMessage("{\"hashBlockSize\":32}"),
+ Parameters: json.RawMessage("{\"blockSize\":32}"),
},
{
Name: "testPicker",
@@ -420,6 +420,11 @@ func TestLoadConfig(t *testing.T) {
configText: errorNoProfileHandlersText,
wantErr: true,
},
+ {
+ name: "errorMultiProfilesUseSingleProfileHandler",
+ configText: errorMultiProfilesUseSingleProfileHandlerText,
+ wantErr: true,
+ },
}
registerNeededPlgugins()
@@ -464,7 +469,7 @@ plugins:
type: test-profile-handler
- type: test-two
parameters:
- hashBlockSize: 32
+ blockSize: 32
- name: testPicker
type: test-picker
schedulingProfiles:
@@ -772,7 +777,7 @@ plugins:
- name: prefixCacheScorer
type: prefix-cache-scorer
parameters:
- hashBlockSize: 32
+ blockSize: 32
- name: maxScorePicker
type: max-score-picker
- name: profileHandler
@@ -797,7 +802,7 @@ plugins:
- name: prefixCacheScorer
type: prefix-cache-scorer
parameters:
- hashBlockSize: 32
+ blockSize: 32
schedulingProfiles:
- name: default
plugins:
@@ -831,7 +836,7 @@ plugins:
- name: prefixCacheScorer
type: prefix-cache-scorer
parameters:
- hashBlockSize: asdf
+ blockSize: asdf
schedulingProfiles:
- name: default
plugins:
@@ -895,3 +900,23 @@ schedulingProfiles:
plugins:
- pluginRef: maxScore
`
+
+// multiple profiles using SingleProfileHandler
+//
+//nolint:dupword
+const errorMultiProfilesUseSingleProfileHandlerText = `
+apiVersion: inference.networking.x-k8s.io/v1alpha1
+kind: EndpointPickerConfig
+plugins:
+- name: profileHandler
+ type: single-profile-handler
+- name: maxScore
+ type: max-score-picker
+schedulingProfiles:
+- name: default
+ plugins:
+ - pluginRef: maxScore
+- name: prof2
+ plugins:
+ - pluginRef: maxScore
+`
diff --git a/pkg/epp/controller/inferenceobjective_reconciler.go b/pkg/epp/controller/inferenceobjective_reconciler.go
index 53bce8646..c8ac5a6c3 100644
--- a/pkg/epp/controller/inferenceobjective_reconciler.go
+++ b/pkg/epp/controller/inferenceobjective_reconciler.go
@@ -18,6 +18,7 @@ package controller
import (
"context"
+ "fmt"
"k8s.io/apimachinery/pkg/api/errors"
ctrl "sigs.k8s.io/controller-runtime"
@@ -48,8 +49,7 @@ func (c *InferenceObjectiveReconciler) Reconcile(ctx context.Context, req ctrl.R
notFound := false
if err := c.Get(ctx, req.NamespacedName, infObjective); err != nil {
if !errors.IsNotFound(err) {
- logger.Error(err, "Unable to get InferenceObjective")
- return ctrl.Result{}, err
+ return ctrl.Result{}, fmt.Errorf("unable to get InferenceObjective - %w", err)
}
notFound = true
}
diff --git a/pkg/epp/controller/inferenceobjective_reconciler_test.go b/pkg/epp/controller/inferenceobjective_reconciler_test.go
index de43d6e63..4ceff5d07 100644
--- a/pkg/epp/controller/inferenceobjective_reconciler_test.go
+++ b/pkg/epp/controller/inferenceobjective_reconciler_test.go
@@ -160,7 +160,7 @@ func TestInferenceObjectiveReconciler(t *testing.T) {
WithObjects(initObjs...).
Build()
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
- ds := datastore.NewDatastore(t.Context(), pmf)
+ ds := datastore.NewDatastore(t.Context(), pmf, 0)
for _, m := range test.objectivessInStore {
ds.ObjectiveSet(m)
}
diff --git a/pkg/epp/controller/inferencepool_reconciler.go b/pkg/epp/controller/inferencepool_reconciler.go
index d8b7668e2..3b52de0ae 100644
--- a/pkg/epp/controller/inferencepool_reconciler.go
+++ b/pkg/epp/controller/inferencepool_reconciler.go
@@ -56,9 +56,7 @@ func (c *InferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Reques
obj = &v1alpha2.InferencePool{}
default:
// Handle unsupported groups gracefully.
- err := fmt.Errorf("unsupported API group: %s", c.PoolGKNN.Group)
- logger.Error(err, "Cannot reconcile InferencePool")
- return ctrl.Result{}, err
+ return ctrl.Result{}, fmt.Errorf("cannot reconcile InferencePool - unsupported API group: %s", c.PoolGKNN.Group)
}
// 2. Perform a single, generic fetch for the object.
@@ -68,8 +66,7 @@ func (c *InferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Reques
c.Datastore.Clear()
return ctrl.Result{}, nil
}
- logger.Error(err, "Unable to get InferencePool")
- return ctrl.Result{}, err
+ return ctrl.Result{}, fmt.Errorf("unable to get InferencePool - %w", err)
}
// 3. Perform common checks using the client.Object interface.
@@ -90,16 +87,14 @@ func (c *InferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Reques
var err error
err = pool.ConvertTo(v1infPool)
if err != nil {
- logger.Error(err, "Failed to convert XInferencePool to InferencePool")
- return ctrl.Result{}, err
+ return ctrl.Result{}, fmt.Errorf("failed to convert XInferencePool to InferencePool - %w", err)
}
default:
return ctrl.Result{}, fmt.Errorf("unsupported API group: %s", c.PoolGKNN.Group)
}
if err := c.Datastore.PoolSet(ctx, c.Reader, v1infPool); err != nil {
- logger.Error(err, "Failed to update datastore")
- return ctrl.Result{}, err
+ return ctrl.Result{}, fmt.Errorf("failed to update datastore - %w", err)
}
return ctrl.Result{}, nil
diff --git a/pkg/epp/controller/inferencepool_reconciler_test.go b/pkg/epp/controller/inferencepool_reconciler_test.go
index 7f6938533..a2bce1256 100644
--- a/pkg/epp/controller/inferencepool_reconciler_test.go
+++ b/pkg/epp/controller/inferencepool_reconciler_test.go
@@ -24,6 +24,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
corev1 "k8s.io/api/core/v1"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
@@ -113,14 +114,14 @@ func TestInferencePoolReconciler(t *testing.T) {
ctx := context.Background()
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
- datastore := datastore.NewDatastore(ctx, pmf)
- inferencePoolReconciler := &InferencePoolReconciler{Reader: fakeClient, Datastore: datastore, PoolGKNN: gknn}
+ ds := datastore.NewDatastore(ctx, pmf, 0)
+ inferencePoolReconciler := &InferencePoolReconciler{Reader: fakeClient, Datastore: ds, PoolGKNN: gknn}
// Step 1: Inception, only ready pods matching pool1 are added to the store.
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
}
- if diff := diffStore(datastore, diffStoreParams{wantPool: pool1, wantPods: []string{"pod1", "pod2"}}); diff != "" {
+ if diff := diffStore(ds, diffStoreParams{wantPool: pool1, wantPods: []string{"pod1-rank-0", "pod2-rank-0"}}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}
@@ -138,7 +139,7 @@ func TestInferencePoolReconciler(t *testing.T) {
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
}
- if diff := diffStore(datastore, diffStoreParams{wantPool: newPool1, wantPods: []string{"pod5"}}); diff != "" {
+ if diff := diffStore(ds, diffStoreParams{wantPool: newPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}
@@ -153,7 +154,7 @@ func TestInferencePoolReconciler(t *testing.T) {
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
}
- if diff := diffStore(datastore, diffStoreParams{wantPool: newPool1, wantPods: []string{"pod5"}}); diff != "" {
+ if diff := diffStore(ds, diffStoreParams{wantPool: newPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}
@@ -167,7 +168,7 @@ func TestInferencePoolReconciler(t *testing.T) {
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
}
- if diff := diffStore(datastore, diffStoreParams{wantPods: []string{}}); diff != "" {
+ if diff := diffStore(ds, diffStoreParams{wantPods: []string{}}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}
}
@@ -180,7 +181,9 @@ type diffStoreParams struct {
func diffStore(datastore datastore.Datastore, params diffStoreParams) string {
gotPool, _ := datastore.PoolGet()
- if diff := cmp.Diff(params.wantPool, gotPool); diff != "" {
+ // controller-runtime fake client may not populate TypeMeta (APIVersion/Kind).
+ // Ignore it when comparing pools.
+ if diff := cmp.Diff(params.wantPool, gotPool, cmpopts.IgnoreTypes(metav1.TypeMeta{})); diff != "" {
return "pool:" + diff
}
@@ -258,14 +261,14 @@ func TestXInferencePoolReconciler(t *testing.T) {
ctx := context.Background()
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
- datastore := datastore.NewDatastore(ctx, pmf)
- inferencePoolReconciler := &InferencePoolReconciler{Reader: fakeClient, Datastore: datastore, PoolGKNN: gknn}
+ ds := datastore.NewDatastore(ctx, pmf, 0)
+ inferencePoolReconciler := &InferencePoolReconciler{Reader: fakeClient, Datastore: ds, PoolGKNN: gknn}
// Step 1: Inception, only ready pods matching pool1 are added to the store.
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
}
- if diff := xDiffStore(t, datastore, xDiffStoreParams{wantPool: pool1, wantPods: []string{"pod1", "pod2"}}); diff != "" {
+ if diff := xDiffStore(t, ds, xDiffStoreParams{wantPool: pool1, wantPods: []string{"pod1-rank-0", "pod2-rank-0"}}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}
@@ -281,7 +284,7 @@ func TestXInferencePoolReconciler(t *testing.T) {
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
}
- if diff := xDiffStore(t, datastore, xDiffStoreParams{wantPool: newPool1, wantPods: []string{"pod5"}}); diff != "" {
+ if diff := xDiffStore(t, ds, xDiffStoreParams{wantPool: newPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}
@@ -296,7 +299,7 @@ func TestXInferencePoolReconciler(t *testing.T) {
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
}
- if diff := xDiffStore(t, datastore, xDiffStoreParams{wantPool: newPool1, wantPods: []string{"pod5"}}); diff != "" {
+ if diff := xDiffStore(t, ds, xDiffStoreParams{wantPool: newPool1, wantPods: []string{"pod5-rank-0"}}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}
@@ -310,7 +313,7 @@ func TestXInferencePoolReconciler(t *testing.T) {
if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil {
t.Errorf("Unexpected InferencePool reconcile error: %v", err)
}
- if diff := xDiffStore(t, datastore, xDiffStoreParams{wantPods: []string{}}); diff != "" {
+ if diff := xDiffStore(t, ds, xDiffStoreParams{wantPods: []string{}}); diff != "" {
t.Errorf("Unexpected diff (+got/-want): %s", diff)
}
}
@@ -333,7 +336,10 @@ func xDiffStore(t *testing.T, datastore datastore.Datastore, params xDiffStorePa
if err != nil {
t.Fatalf("failed to convert InferencePool to XInferencePool: %v", err)
}
- if diff := cmp.Diff(params.wantPool, gotXPool); diff != "" {
+
+ // controller-runtime fake client may not populate TypeMeta (APIVersion/Kind).
+ // Ignore it when comparing pools.
+ if diff := cmp.Diff(params.wantPool, gotXPool, cmpopts.IgnoreTypes(metav1.TypeMeta{})); diff != "" {
return "pool:" + diff
}
diff --git a/pkg/epp/controller/pod_reconciler.go b/pkg/epp/controller/pod_reconciler.go
index 3cd7c2574..b3a78ef92 100644
--- a/pkg/epp/controller/pod_reconciler.go
+++ b/pkg/epp/controller/pod_reconciler.go
@@ -18,11 +18,11 @@ package controller
import (
"context"
+ "fmt"
"github.com/go-logr/logr"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
- "k8s.io/apimachinery/pkg/types"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/event"
@@ -52,11 +52,10 @@ func (c *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R
pod := &corev1.Pod{}
if err := c.Get(ctx, req.NamespacedName, pod); err != nil {
if apierrors.IsNotFound(err) {
- c.Datastore.PodDelete(req.NamespacedName)
+ c.Datastore.PodDelete(req.Name)
return ctrl.Result{}, nil
}
- logger.V(logutil.DEFAULT).Error(err, "Unable to get pod")
- return ctrl.Result{}, err
+ return ctrl.Result{}, fmt.Errorf("unable to get pod - %w", err)
}
c.updateDatastore(logger, pod)
@@ -90,10 +89,9 @@ func (c *PodReconciler) SetupWithManager(mgr ctrl.Manager) error {
}
func (c *PodReconciler) updateDatastore(logger logr.Logger, pod *corev1.Pod) {
- namespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace}
if !podutil.IsPodReady(pod) || !c.Datastore.PoolLabelsMatch(pod.Labels) {
logger.V(logutil.DEBUG).Info("Pod removed or not added")
- c.Datastore.PodDelete(namespacedName)
+ c.Datastore.PodDelete(pod.Name)
} else {
if c.Datastore.PodUpdateOrAddIfNotExist(pod) {
logger.V(logutil.DEFAULT).Info("Pod added")
diff --git a/pkg/epp/controller/pod_reconciler_test.go b/pkg/epp/controller/pod_reconciler_test.go
index 5ceb3efdb..28f817310 100644
--- a/pkg/epp/controller/pod_reconciler_test.go
+++ b/pkg/epp/controller/pod_reconciler_test.go
@@ -196,7 +196,7 @@ func TestPodReconciler(t *testing.T) {
Build()
// Configure the initial state of the datastore.
- store := datastore.NewDatastore(t.Context(), pmf)
+ store := datastore.NewDatastore(t.Context(), pmf, 0)
_ = store.PoolSet(t.Context(), fakeClient, test.pool)
for _, pod := range test.existingPods {
store.PodUpdateOrAddIfNotExist(pod)
@@ -213,7 +213,7 @@ func TestPodReconciler(t *testing.T) {
var gotPods []*corev1.Pod
for _, pm := range store.PodList(backendmetrics.AllPodsPredicate) {
- pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: pm.GetPod().NamespacedName.Name, Namespace: pm.GetPod().NamespacedName.Namespace}, Status: corev1.PodStatus{PodIP: pm.GetPod().Address}}
+ pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: pm.GetPod().PodName, Namespace: pm.GetPod().NamespacedName.Namespace}, Status: corev1.PodStatus{PodIP: pm.GetPod().GetIPAddress()}}
gotPods = append(gotPods, pod)
}
if !cmp.Equal(gotPods, test.wantPods, cmpopts.SortSlices(func(a, b *corev1.Pod) bool { return a.Name < b.Name })) {
diff --git a/pkg/epp/datalayer/collector_test.go b/pkg/epp/datalayer/collector_test.go
index 2d47de30a..0e3b9151b 100644
--- a/pkg/epp/datalayer/collector_test.go
+++ b/pkg/epp/datalayer/collector_test.go
@@ -24,8 +24,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- corev1 "k8s.io/api/core/v1"
- metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/mocks"
)
@@ -45,14 +44,12 @@ func (d *DummySource) Collect(ctx context.Context, ep Endpoint) error {
func defaultEndpoint() Endpoint {
ms := NewEndpoint()
- pod := &corev1.Pod{
- ObjectMeta: metav1.ObjectMeta{
+ pod := &PodInfo{
+ NamespacedName: types.NamespacedName{
Name: "pod-name",
Namespace: "default",
},
- Status: corev1.PodStatus{
- PodIP: "1.2.3.4",
- },
+ Address: "1.2.3.4:5678",
}
ms.UpdatePod(pod)
return ms
diff --git a/pkg/epp/datalayer/endpoint.go b/pkg/epp/datalayer/endpoint.go
index 7898a7a41..74c11905e 100644
--- a/pkg/epp/datalayer/endpoint.go
+++ b/pkg/epp/datalayer/endpoint.go
@@ -19,14 +19,12 @@ package datalayer
import (
"fmt"
"sync/atomic"
-
- corev1 "k8s.io/api/core/v1"
)
// EndpointPodState allows management of the Pod related attributes.
type EndpointPodState interface {
GetPod() *PodInfo
- UpdatePod(*corev1.Pod)
+ UpdatePod(*PodInfo)
}
// EndpointMetricsState allows management of the Metrics related attributes.
@@ -35,23 +33,11 @@ type EndpointMetricsState interface {
UpdateMetrics(*Metrics)
}
-// EndpointRunningRequestsState allows management of the Pod related attributes.
-type EndpointRunningRequestsState interface {
- GetRunningRequests() *RequestPriorityQueue
- AddRequest(requestID string, tpot float64) bool
- RemoveRequest(requestID string) bool
- UpdateRequest(requestID string, tpot float64) bool
- GetRequestCount() int
- ContainsRequest(requestID string) bool
- PeekRequestPriorityQueue() *Request
-}
-
// Endpoint represents an inference serving endpoint and its related attributes.
type Endpoint interface {
fmt.Stringer
EndpointPodState
EndpointMetricsState
- EndpointRunningRequestsState
AttributeMap
}
@@ -79,16 +65,8 @@ func (srv *ModelServer) GetPod() *PodInfo {
return srv.pod.Load()
}
-func (srv *ModelServer) UpdatePod(k8sPod *corev1.Pod) {
- currentPod := srv.GetPod()
- updatedPod := ToPodInfo(k8sPod)
-
- // Preserve the existing running requests queue if it exists
- if currentPod != nil && currentPod.GetRunningRequests() != nil {
- updatedPod.RunningRequests = currentPod.GetRunningRequests()
- }
-
- srv.pod.Store(updatedPod)
+func (srv *ModelServer) UpdatePod(pod *PodInfo) {
+ srv.pod.Store(pod)
}
func (srv *ModelServer) GetMetrics() *Metrics {
@@ -99,67 +77,6 @@ func (srv *ModelServer) UpdateMetrics(metrics *Metrics) {
srv.metrics.Store(metrics)
}
-// New methods for priority queue integration
-func (srv *ModelServer) GetRunningRequests() *RequestPriorityQueue {
- pod := srv.GetPod()
- if pod == nil {
- return nil
- }
- return pod.RunningRequests
-}
-
-func (srv *ModelServer) AddRequest(requestID string, tpot float64) bool {
- pod := srv.GetPod()
- if pod == nil || pod.RunningRequests == nil {
- return false
- }
- success := pod.RunningRequests.Add(requestID, tpot)
- // No need to update metrics since we removed ActualRunningRequests
- return success
-}
-
-func (srv *ModelServer) RemoveRequest(requestID string) bool {
- pod := srv.GetPod()
- if pod == nil || pod.RunningRequests == nil {
- return false
- }
- _, success := pod.RunningRequests.Remove(requestID)
- // No need to update metrics since we removed ActualRunningRequests
- return success
-}
-
-func (srv *ModelServer) UpdateRequest(requestID string, tpot float64) bool {
- pod := srv.GetPod()
- if pod == nil || pod.RunningRequests == nil {
- return false
- }
- return pod.RunningRequests.Update(requestID, tpot)
-}
-
-func (srv *ModelServer) GetRequestCount() int {
- pod := srv.GetPod()
- if pod == nil || pod.RunningRequests == nil {
- return 0
- }
- return pod.RunningRequests.GetSize()
-}
-
-func (srv *ModelServer) ContainsRequest(requestID string) bool {
- pod := srv.GetPod()
- if pod == nil || pod.RunningRequests == nil {
- return false
- }
- return pod.RunningRequests.Contains(requestID)
-}
-
-func (srv *ModelServer) PeekRequestPriorityQueue() *Request {
- pod := srv.GetPod()
- if pod == nil || pod.RunningRequests == nil {
- return nil
- }
- return pod.RunningRequests.Peek()
-}
-
func (srv *ModelServer) Put(key string, value Cloneable) {
srv.attributes.Put(key, value)
}
diff --git a/pkg/epp/datalayer/factory.go b/pkg/epp/datalayer/factory.go
index eca7697e5..989527c6c 100644
--- a/pkg/epp/datalayer/factory.go
+++ b/pkg/epp/datalayer/factory.go
@@ -21,7 +21,6 @@ import (
"sync"
"time"
- corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/log"
@@ -45,7 +44,7 @@ type PoolInfo interface {
// providing methods to allocate and retire endpoints. This can potentially be used for
// pooled memory or other management chores in the implementation.
type EndpointFactory interface {
- NewEndpoint(parent context.Context, inpod *corev1.Pod, poolinfo PoolInfo) Endpoint
+ NewEndpoint(parent context.Context, inpod *PodInfo, poolinfo PoolInfo) Endpoint
ReleaseEndpoint(ep Endpoint)
}
@@ -70,8 +69,8 @@ func NewEndpointFactory(sources []DataSource, refreshMetricsInterval time.Durati
// NewEndpoint implements EndpointFactory.NewEndpoint.
// Creates a new endpoint and starts its associated collector with its own ticker.
// Guards against multiple concurrent calls for the same endpoint.
-func (lc *EndpointLifecycle) NewEndpoint(parent context.Context, inpod *corev1.Pod, _ PoolInfo) Endpoint {
- key := types.NamespacedName{Namespace: inpod.Namespace, Name: inpod.Name}
+func (lc *EndpointLifecycle) NewEndpoint(parent context.Context, inpod *PodInfo, _ PoolInfo) Endpoint {
+ key := types.NamespacedName{Namespace: inpod.GetNamespacedName().Namespace, Name: inpod.GetNamespacedName().Name}
logger := log.FromContext(parent).WithValues("pod", key)
if _, ok := lc.collectors.Load(key); ok {
diff --git a/pkg/epp/datalayer/metrics.go b/pkg/epp/datalayer/metrics.go
index 5869165c9..2febcb4d0 100644
--- a/pkg/epp/datalayer/metrics.go
+++ b/pkg/epp/datalayer/metrics.go
@@ -32,6 +32,7 @@ type Metrics struct {
WaitingQueueSize int
KVCacheUsagePercent float64
KvCacheMaxTokenCapacity int
+ CacheBlockSize int
// UpdateTime records the last time when the metrics were updated.
UpdateTime time.Time
@@ -75,6 +76,7 @@ func (m *Metrics) Clone() *Metrics {
WaitingQueueSize: m.WaitingQueueSize,
KVCacheUsagePercent: m.KVCacheUsagePercent,
KvCacheMaxTokenCapacity: m.KvCacheMaxTokenCapacity,
+ CacheBlockSize: m.CacheBlockSize,
UpdateTime: m.UpdateTime,
}
}
diff --git a/pkg/epp/datalayer/metrics/client.go b/pkg/epp/datalayer/metrics/client.go
index 962a2a584..c59850ac5 100644
--- a/pkg/epp/datalayer/metrics/client.go
+++ b/pkg/epp/datalayer/metrics/client.go
@@ -24,6 +24,7 @@ import (
"time"
"github.com/prometheus/common/expfmt"
+ "github.com/prometheus/common/model"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
)
@@ -83,7 +84,7 @@ func (cl *client) Get(ctx context.Context, target *url.URL, ep datalayer.Address
return nil, fmt.Errorf("unexpected status code from %s: %v", ep.GetNamespacedName(), resp.StatusCode)
}
- parser := expfmt.TextParser{}
+ parser := expfmt.NewTextParser(model.LegacyValidation)
metricFamilies, err := parser.TextToMetricFamilies(resp.Body)
if err != nil {
return nil, err
diff --git a/pkg/epp/datalayer/metrics/datasource.go b/pkg/epp/datalayer/metrics/datasource.go
index 7dcdc97ba..1e14d1b1a 100644
--- a/pkg/epp/datalayer/metrics/datasource.go
+++ b/pkg/epp/datalayer/metrics/datasource.go
@@ -21,11 +21,8 @@ import (
"crypto/tls"
"errors"
"fmt"
- "net"
"net/url"
- "strconv"
"sync"
- "sync/atomic"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
)
@@ -37,9 +34,8 @@ const (
// DataSource is a Model Server Protocol (MSP) compliant metrics data source,
// returning Prometheus formatted metrics for an endpoint.
type DataSource struct {
- metricsScheme string // scheme to use in metrics URL
- metricsPort atomic.Pointer[string] // target port to use in metrics URL
- metricsPath string // path to use in metrics URL
+ metricsScheme string // scheme to use in metrics URL
+ metricsPath string // path to use in metrics URL
client Client // client (e.g. a wrapped http.Client) used to get metrics
extractors sync.Map // key: name, value: extractor
@@ -49,7 +45,7 @@ type DataSource struct {
// the provided client factory. If ClientFactory is nil, a default factory is used.
// The Scheme, port and path are command line options. It should be noted that
// a port value of zero is set if the command line is unspecified.
-func NewDataSource(metricsScheme string, metricsPort int32, metricsPath string, skipCertVerification bool, cl Client) *DataSource {
+func NewDataSource(metricsScheme string, metricsPath string, skipCertVerification bool, cl Client) *DataSource {
if metricsScheme == "https" {
httpsTransport := baseTransport.Clone()
httpsTransport.TLSClientConfig = &tls.Config{
@@ -67,25 +63,9 @@ func NewDataSource(metricsScheme string, metricsPort int32, metricsPath string,
metricsPath: metricsPath,
client: cl,
}
- dataSrc.SetPort(metricsPort)
return dataSrc
}
-// SetPort updates the port used for metrics scraping.
-// The port value can only be set once (i.e., if set by command line,
-// do not overwrite from Pool.Spec). A port value of 0 (i.e., unspecified
-// command line value) is ignored.
-// TODO: https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/1398
-func (dataSrc *DataSource) SetPort(metricsPort int32) {
- if dataSrc.metricsPort.Load() != nil { // do not overwrite
- return
- }
- if metricsPort != 0 { // ignore zero value for port
- port := strconv.Itoa(int(metricsPort))
- dataSrc.metricsPort.Store(&port)
- }
-}
-
// Name returns the metrics data source name.
func (dataSrc *DataSource) Name() string {
return DataSourceName
@@ -132,7 +112,7 @@ func (dataSrc *DataSource) Collect(ctx context.Context, ep datalayer.Endpoint) e
func (dataSrc *DataSource) getMetricsEndpoint(ep datalayer.Addressable) *url.URL {
return &url.URL{
Scheme: dataSrc.metricsScheme,
- Host: net.JoinHostPort(ep.GetIPAddress(), *dataSrc.metricsPort.Load()),
+ Host: ep.GetMetricsHost(),
Path: dataSrc.metricsPath,
}
}
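Sketch (not part of the patch) of the endpoint URL construction after SetPort was removed: the port no longer needs to be plumbed through the data source because the endpoint's metrics host already carries "ip:port".

package metrics // illustrative snippet, assuming it sits next to datasource.go

import "net/url"

// metricsURL mirrors getMetricsEndpoint: scheme and path come from the data source,
// metricsHost from Addressable.GetMetricsHost() (i.e. PodInfo.MetricsHost).
func metricsURL(scheme, metricsHost, path string) string {
	u := url.URL{Scheme: scheme, Host: metricsHost, Path: path}
	return u.String()
}

// metricsURL("http", "10.0.0.7:8000", "/metrics") == "http://10.0.0.7:8000/metrics"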
diff --git a/pkg/epp/datalayer/metrics/extractor.go b/pkg/epp/datalayer/metrics/extractor.go
index 08105196d..6c6978c87 100644
--- a/pkg/epp/datalayer/metrics/extractor.go
+++ b/pkg/epp/datalayer/metrics/extractor.go
@@ -37,6 +37,8 @@ const (
LoraInfoRunningAdaptersMetricName = "running_lora_adapters"
LoraInfoWaitingAdaptersMetricName = "waiting_lora_adapters"
LoraInfoMaxAdaptersMetricName = "max_lora"
+
+ CacheConfigBlockSizeInfoMetricName = "block_size"
)
// Extractor implements the metrics extraction based on the model
@@ -49,8 +51,8 @@ type Extractor struct {
// configured with the given metrics' specifications.
// These are mandatory metrics per the MSP specification, and are used
// as the basis for the built-in scheduling plugins.
-func NewExtractor(queueSpec, runningSpec, kvusageSpec, loraSpec string) (*Extractor, error) {
- mapping, err := NewMapping(queueSpec, runningSpec, kvusageSpec, loraSpec)
+func NewExtractor(queueSpec, runningSpec, kvusageSpec, loraSpec, cacheInfoSpec string) (*Extractor, error) {
+ mapping, err := NewMapping(queueSpec, runningSpec, kvusageSpec, loraSpec, cacheInfoSpec)
if err != nil {
return nil, fmt.Errorf("failed to create extractor metrics Mapping - %w", err)
}
@@ -120,6 +122,16 @@ func (ext *Extractor) Extract(ctx context.Context, data any, ep datalayer.Endpoi
}
}
+ if spec := ext.mapping.CacheInfo; spec != nil { // extract CacheInfo-specific metrics
+ metric, err := spec.getLatestMetric(families)
+ if err != nil {
+ errs = append(errs, err)
+ } else if metric != nil {
+ populateCacheInfoMetrics(clone, metric, &errs)
+ updated = true
+ }
+ }
+
if updated {
clone.UpdateTime = time.Now()
ep.UpdateMetrics(clone)
@@ -154,6 +166,23 @@ func populateLoRAMetrics(clone *datalayer.Metrics, metric *dto.Metric, errs *[]e
}
}
+// populateCacheInfoMetrics updates the metrics with cache info from the metric labels.
+func populateCacheInfoMetrics(clone *datalayer.Metrics, metric *dto.Metric, errs *[]error) {
+ clone.CacheBlockSize = 0
+ for _, label := range metric.GetLabel() {
+ if label.GetName() == CacheConfigBlockSizeInfoMetricName {
+ if label.GetValue() != "" {
+ if val, err := strconv.Atoi(label.GetValue()); err == nil {
+ clone.CacheBlockSize = val
+ break
+ } else {
+ *errs = append(*errs, err)
+ }
+ }
+ }
+ }
+}
+
// addAdapters splits a comma-separated adapter list and stores keys with default value 0.
func addAdapters(m map[string]int, csv string) {
for _, name := range strings.Split(csv, ",") {
diff --git a/pkg/epp/datalayer/metrics/mapping.go b/pkg/epp/datalayer/metrics/mapping.go
index e92f1f102..7b1fed9c1 100644
--- a/pkg/epp/datalayer/metrics/mapping.go
+++ b/pkg/epp/datalayer/metrics/mapping.go
@@ -27,10 +27,11 @@ type Mapping struct {
TotalRunningRequests *Spec
KVCacheUtilization *Spec
LoraRequestInfo *LoRASpec
+ CacheInfo *Spec
}
// NewMapping creates a metrics.Mapping from the input specification strings.
-func NewMapping(queue, running, kvusage, lora string) (*Mapping, error) {
+func NewMapping(queue, running, kvusage, lora, cacheInfo string) (*Mapping, error) {
var errs []error
queueSpec, err := parseStringToSpec(queue)
@@ -49,6 +50,12 @@ func NewMapping(queue, running, kvusage, lora string) (*Mapping, error) {
if err != nil {
errs = append(errs, err)
}
+
+ cacheInfoSpec, err := parseStringToSpec(cacheInfo)
+ if err != nil {
+ errs = append(errs, err)
+ }
+
if len(errs) != 0 {
return nil, errors.Join(errs...)
}
@@ -57,5 +64,6 @@ func NewMapping(queue, running, kvusage, lora string) (*Mapping, error) {
TotalRunningRequests: runningSpec,
KVCacheUtilization: kvusageSpec,
LoraRequestInfo: loraSpec,
+ CacheInfo: cacheInfoSpec,
}, nil
}
diff --git a/pkg/epp/datalayer/podinfo.go b/pkg/epp/datalayer/podinfo.go
index 5f2d417c6..7cbd6d886 100644
--- a/pkg/epp/datalayer/podinfo.go
+++ b/pkg/epp/datalayer/podinfo.go
@@ -19,40 +19,25 @@ package datalayer
import (
"fmt"
- corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
)
// Addressable supports getting an IP address and a namespaced name.
type Addressable interface {
GetIPAddress() string
+ GetPort() string
+ GetMetricsHost() string
GetNamespacedName() types.NamespacedName
- GetRunningRequests() *RequestPriorityQueue
}
// PodInfo represents the relevant Kubernetes Pod state of an inference server.
type PodInfo struct {
- NamespacedName types.NamespacedName
- Address string
- Labels map[string]string
- RunningRequests *RequestPriorityQueue
-}
-
-// ToPodInfo converts a Kubernetes API Pod to its internal representation.
-func ToPodInfo(pod *corev1.Pod) *PodInfo {
- labels := make(map[string]string, len(pod.GetLabels()))
- for key, value := range pod.GetLabels() {
- labels[key] = value
- }
- return &PodInfo{
- NamespacedName: types.NamespacedName{
- Name: pod.Name,
- Namespace: pod.Namespace,
- },
- Address: pod.Status.PodIP,
- Labels: labels,
- RunningRequests: NewRequestPriorityQueue(),
- }
+ NamespacedName types.NamespacedName
+ PodName string
+ Address string
+ Port string
+ MetricsHost string
+ Labels map[string]string
}
// String returns a string representation of the pod.
@@ -73,18 +58,16 @@ func (p *PodInfo) Clone() *PodInfo {
for key, value := range p.Labels {
clonedLabels[key] = value
}
- var clonedRequests *RequestPriorityQueue
- if p.RunningRequests != nil {
- clonedRequests = p.RunningRequests.Clone()
- }
return &PodInfo{
NamespacedName: types.NamespacedName{
Name: p.NamespacedName.Name,
Namespace: p.NamespacedName.Namespace,
},
- Address: p.Address,
- Labels: clonedLabels,
- RunningRequests: clonedRequests,
+ PodName: p.PodName,
+ Address: p.Address,
+ Port: p.Port,
+ MetricsHost: p.MetricsHost,
+ Labels: clonedLabels,
}
}
@@ -98,7 +81,12 @@ func (p *PodInfo) GetIPAddress() string {
return p.Address
}
-// GetRunningRequests returns the running request queue for the Pod.
-func (p *PodInfo) GetRunningRequests() *RequestPriorityQueue {
- return p.RunningRequests
+// GetPort returns the Pod's inference port.
+func (p *PodInfo) GetPort() string {
+ return p.Port
+}
+
+// GetMetricsHost returns the pod's metrics host (ip:port).
+func (p *PodInfo) GetMetricsHost() string {
+ return p.MetricsHost
}
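Hypothetical construction of the extended PodInfo (not part of the patch). Whether callers derive MetricsHost via net.JoinHostPort, and where the "-rank-0" suffix seen in the reconciler tests is appended, are assumptions here; the datastore is where PodInfo is actually built from a corev1.Pod.

package datalayer // illustrative snippet

import (
	"net"

	"k8s.io/apimachinery/pkg/types"
)

// newPodInfo sketches how the new fields fit together for a single-engine pod.
func newPodInfo(name, namespace, ip, port string, labels map[string]string) *PodInfo {
	return &PodInfo{
		NamespacedName: types.NamespacedName{Name: name + "-rank-0", Namespace: namespace},
		PodName:        name,
		Address:        ip,
		Port:           port,
		MetricsHost:    net.JoinHostPort(ip, port), // e.g. "10.0.0.7:8000"
		Labels:         labels,
	}
}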
diff --git a/pkg/epp/datalayer/podinfo_test.go b/pkg/epp/datalayer/podinfo_test.go
index 91256cae7..baf804a22 100644
--- a/pkg/epp/datalayer/podinfo_test.go
+++ b/pkg/epp/datalayer/podinfo_test.go
@@ -55,17 +55,6 @@ var (
}
)
-func TestToPodInfo(t *testing.T) {
- podinfo := ToPodInfo(pod)
- if podinfo.RunningRequests == nil {
- t.Fatal("Expected RunningRequests to be initialized")
- }
- podinfo.RunningRequests = nil // Reset to nil for comparison, this is necessary because the podinfo is created with a new map each time
- if diff := cmp.Diff(expected, podinfo); diff != "" {
- t.Errorf("Unexpected output (-want +got): %v", diff)
- }
-}
-
func TestPodInfoClone(t *testing.T) {
clone := expected.Clone()
assert.NotSame(t, expected, clone)
@@ -78,7 +67,17 @@ func TestPodInfoClone(t *testing.T) {
}
func TestPodInfoString(t *testing.T) {
- podinfo := ToPodInfo(pod)
+ podinfo := PodInfo{
+ NamespacedName: types.NamespacedName{
+ Name: pod.Name,
+ Namespace: pod.Namespace,
+ },
+ PodName: pod.Name,
+ Address: pod.Status.PodIP,
+ Port: "8000",
+ MetricsHost: "127.0.0.1:8000",
+ Labels: labels,
+ }
s := podinfo.String()
assert.Contains(t, s, name)
diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go
index e2e9bebbc..5dcd0f4a0 100644
--- a/pkg/epp/datastore/datastore.go
+++ b/pkg/epp/datastore/datastore.go
@@ -20,7 +20,9 @@ import (
"context"
"errors"
"fmt"
+ "net"
"reflect"
+ "strconv"
"sync"
corev1 "k8s.io/api/core/v1"
@@ -33,7 +35,6 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
- dlmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/metrics"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
podutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pod"
)
@@ -62,31 +63,20 @@ type Datastore interface {
// PodList lists pods matching the given predicate.
PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics
PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool
- PodDelete(namespacedName types.NamespacedName)
-
- // Request management operations
- // PodAddRequest adds a request to a specific pod's running requests queue
- PodAddRequest(podName types.NamespacedName, requestID string, tpot float64) error
- // PodRemoveRequest removes a request from a specific pod's running requests queue
- PodRemoveRequest(podName types.NamespacedName, requestID string) error
- // PodUpdateRequest updates the TPOT value for a request in a specific pod's queue
- PodUpdateRequest(podName types.NamespacedName, requestID string, tpot float64) error
- // PodGetRunningRequests returns the priority queue for a specific pod
- PodGetRunningRequests(podName types.NamespacedName) (*datalayer.RequestPriorityQueue, error)
- // PodGetRequestCount returns the number of running requests for a specific pod
- PodGetRequestCount(podName types.NamespacedName) (int, error)
+ PodDelete(podName string)
// Clears the store state, happens when the pool gets deleted.
Clear()
}
-func NewDatastore(parentCtx context.Context, epFactory datalayer.EndpointFactory) Datastore {
+func NewDatastore(parentCtx context.Context, epFactory datalayer.EndpointFactory, modelServerMetricsPort int32) Datastore {
store := &datastore{
- parentCtx: parentCtx,
- poolAndObjectivesMu: sync.RWMutex{},
- objectives: make(map[string]*v1alpha2.InferenceObjective),
- pods: &sync.Map{},
- epf: epFactory,
+ parentCtx: parentCtx,
+ poolAndObjectivesMu: sync.RWMutex{},
+ objectives: make(map[string]*v1alpha2.InferenceObjective),
+ pods: &sync.Map{},
+ modelServerMetricsPort: modelServerMetricsPort,
+ epf: epFactory,
}
return store
}
@@ -101,7 +91,10 @@ type datastore struct {
objectives map[string]*v1alpha2.InferenceObjective
// key: types.NamespacedName, value: backendmetrics.PodMetrics
pods *sync.Map
- epf datalayer.EndpointFactory
+ // modelServerMetricsPort is the metrics port from the EPP command line argument.
+ // It is used only when there is a single inference engine per pod.
+ modelServerMetricsPort int32
+ epf datalayer.EndpointFactory
}
func (ds *datastore) Clear() {
@@ -129,11 +122,6 @@ func (ds *datastore) PoolSet(ctx context.Context, reader client.Reader, pool *v1
oldPool := ds.pool
ds.pool = pool
- if oldPool == nil || pool.Spec.TargetPorts[0] != oldPool.Spec.TargetPorts[0] {
- if source, found := datalayer.GetNamedSource[*dlmetrics.DataSource](dlmetrics.DataSourceName); found {
- source.SetPort(int32(pool.Spec.TargetPorts[0].Number))
- }
- }
if oldPool == nil || !reflect.DeepEqual(pool.Spec.Selector, oldPool.Spec.Selector) {
logger.V(logutil.DEFAULT).Info("Updating inference pool endpoints", "selector", pool.Spec.Selector)
// A full resync is required to address two cases:
@@ -227,126 +215,65 @@ func (ds *datastore) PodList(predicate func(backendmetrics.PodMetrics) bool) []b
}
func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool {
- namespacedName := types.NamespacedName{
- Name: pod.Name,
- Namespace: pod.Namespace,
- }
- var pm backendmetrics.PodMetrics
- existing, ok := ds.pods.Load(namespacedName)
- if !ok {
- pm = ds.epf.NewEndpoint(ds.parentCtx, pod, ds)
- ds.pods.Store(namespacedName, pm)
- } else {
- pm = existing.(backendmetrics.PodMetrics)
- }
- // Update pod properties if anything changed.
- pm.UpdatePod(pod)
- return ok
-}
-
-func (ds *datastore) PodDelete(namespacedName types.NamespacedName) {
- v, ok := ds.pods.LoadAndDelete(namespacedName)
- if ok {
- ds.epf.ReleaseEndpoint(v.(backendmetrics.PodMetrics))
- }
-}
-
-// /// Request Management APIs ///
-
-func (ds *datastore) PodAddRequest(podName types.NamespacedName, requestID string, tpot float64) error {
- pm, ok := ds.pods.Load(podName)
- if !ok {
- return fmt.Errorf("pod %s not found in datastore", podName)
- }
-
- // TODO add to universal request map if needed for global tracking
-
- podMetrics := pm.(backendmetrics.PodMetrics)
- runningRequests := podMetrics.GetRunningRequests()
- if runningRequests == nil {
- return fmt.Errorf("pod %s does not have running requests queue initialized", podName)
- }
-
- // Request flow in datalayer
- //
- // Add request
-
- if !runningRequests.Add(requestID, tpot) {
- return fmt.Errorf("request %s already exists in pod %s", requestID, podName)
- }
-
- return nil
-}
-
-func (ds *datastore) PodRemoveRequest(podName types.NamespacedName, requestID string) error {
- pm, ok := ds.pods.Load(podName)
- if !ok {
- return fmt.Errorf("pod %s not found in datastore", podName)
- }
-
- // Request removal from universal request map if needed for global tracking
-
- podMetrics := pm.(backendmetrics.PodMetrics)
- runningRequests := podMetrics.GetRunningRequests()
- if runningRequests == nil {
- return fmt.Errorf("pod %s does not have running requests queue initialized", podName)
- }
-
- _, removed := runningRequests.Remove(requestID)
- if !removed {
- return fmt.Errorf("request %s not found in pod %s", requestID, podName)
- }
-
- return nil
-}
-
-func (ds *datastore) PodUpdateRequest(podName types.NamespacedName, requestID string, tpot float64) error {
- pm, ok := ds.pods.Load(podName)
- if !ok {
- return fmt.Errorf("pod %s not found in datastore", podName)
+ if ds.pool == nil {
+ return true
}
- podMetrics := pm.(backendmetrics.PodMetrics)
- runningRequests := podMetrics.GetRunningRequests()
- if runningRequests == nil {
- return fmt.Errorf("pod %s does not have running requests queue initialized", podName)
+ labels := make(map[string]string, len(pod.GetLabels()))
+ for key, value := range pod.GetLabels() {
+ labels[key] = value
}
- if !runningRequests.Update(requestID, tpot) {
- return fmt.Errorf("request %s not found in pod %s", requestID, podName)
+ modelServerMetricsPort := 0
+ if len(ds.pool.Spec.TargetPorts) == 1 {
+ modelServerMetricsPort = int(ds.modelServerMetricsPort)
}
-
- return nil
-}
-
-func (ds *datastore) PodGetRunningRequests(podName types.NamespacedName) (*datalayer.RequestPriorityQueue, error) {
- pm, ok := ds.pods.Load(podName)
- if !ok {
- return nil, fmt.Errorf("pod %s not found in datastore", podName)
+ pods := []*datalayer.PodInfo{}
+ for idx, port := range ds.pool.Spec.TargetPorts {
+ metricsPort := modelServerMetricsPort
+ if metricsPort == 0 {
+ metricsPort = int(port.Number)
+ }
+ pods = append(pods,
+ &datalayer.PodInfo{
+ NamespacedName: types.NamespacedName{
+ Name: pod.Name + "-rank-" + strconv.Itoa(idx),
+ Namespace: pod.Namespace,
+ },
+ PodName: pod.Name,
+ Address: pod.Status.PodIP,
+ Port: strconv.Itoa(int(port.Number)),
+ MetricsHost: net.JoinHostPort(pod.Status.PodIP, strconv.Itoa(metricsPort)),
+ Labels: labels,
+ })
}
- podMetrics := pm.(backendmetrics.PodMetrics)
- runningRequests := podMetrics.GetRunningRequests()
- if runningRequests == nil {
- return nil, fmt.Errorf("pod %s does not have running requests queue initialized", podName)
+ result := true
+ for _, podInfo := range pods {
+ var pm backendmetrics.PodMetrics
+ existing, ok := ds.pods.Load(podInfo.NamespacedName)
+ if !ok {
+ pm = ds.epf.NewEndpoint(ds.parentCtx, podInfo, ds)
+ ds.pods.Store(podInfo.NamespacedName, pm)
+ result = false
+ } else {
+ pm = existing.(backendmetrics.PodMetrics)
+ }
+ // Update pod properties if anything changed.
+ pm.UpdatePod(podInfo)
}
-
- return runningRequests, nil
+ return result
}
-func (ds *datastore) PodGetRequestCount(podName types.NamespacedName) (int, error) {
- pm, ok := ds.pods.Load(podName)
- if !ok {
- return 0, fmt.Errorf("pod %s not found in datastore", podName)
- }
-
- podMetrics := pm.(backendmetrics.PodMetrics)
- runningRequests := podMetrics.GetRunningRequests()
- if runningRequests == nil {
- return 0, fmt.Errorf("pod %s does not have running requests queue initialized", podName)
- }
-
- return runningRequests.GetSize(), nil
+func (ds *datastore) PodDelete(podName string) {
+ ds.pods.Range(func(k, v any) bool {
+ pm := v.(backendmetrics.PodMetrics)
+ if pm.GetPod().PodName == podName {
+ ds.pods.Delete(k)
+ ds.epf.ReleaseEndpoint(pm)
+ }
+ return true
+ })
}
func (ds *datastore) podResyncAll(ctx context.Context, reader client.Reader) error {
@@ -376,9 +303,9 @@ func (ds *datastore) podResyncAll(ctx context.Context, reader client.Reader) err
// Remove pods that don't belong to the pool or not ready any more.
ds.pods.Range(func(k, v any) bool {
pm := v.(backendmetrics.PodMetrics)
- if exist := activePods[pm.GetPod().NamespacedName.Name]; !exist {
+ if exist := activePods[pm.GetPod().PodName]; !exist {
logger.V(logutil.VERBOSE).Info("Removing pod", "pod", pm.GetPod())
- ds.PodDelete(pm.GetPod().NamespacedName)
+ ds.PodDelete(pm.GetPod().PodName)
}
return true
})
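For orientation (illustrative, not part of the change): PodUpdateOrAddIfNotExist now fans each Kubernetes Pod out into one endpoint per InferencePool target port, keyed `<pod-name>-rank-<idx>`, and uses the EPP-level metrics port only when exactly one target port is configured, otherwise each target port doubles as its own metrics port. A standalone sketch of just that mapping, with hypothetical names and ports:

package main

import (
	"fmt"
	"net"
	"strconv"
)

// rankEndpoints mirrors the fan-out above: one logical endpoint per target port.
// The EPP-level metrics port is honored only when a single target port is configured.
func rankEndpoints(podName, podIP string, targetPorts []int, eppMetricsPort int) []string {
	metricsPort := 0
	if len(targetPorts) == 1 {
		metricsPort = eppMetricsPort
	}
	var out []string
	for idx, port := range targetPorts {
		mp := metricsPort
		if mp == 0 {
			mp = port
		}
		out = append(out, fmt.Sprintf("%s-rank-%d -> %s (metrics %s)",
			podName, idx,
			net.JoinHostPort(podIP, strconv.Itoa(port)),
			net.JoinHostPort(podIP, strconv.Itoa(mp))))
	}
	return out
}

func main() {
	for _, e := range rankEndpoints("vllm-0", "10.0.0.7", []int{8000, 8001}, 0) {
		fmt.Println(e)
	}
	// vllm-0-rank-0 -> 10.0.0.7:8000 (metrics 10.0.0.7:8000)
	// vllm-0-rank-1 -> 10.0.0.7:8001 (metrics 10.0.0.7:8001)
}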
diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go
index 271c31ee7..ee59071e6 100644
--- a/pkg/epp/datastore/datastore_test.go
+++ b/pkg/epp/datastore/datastore_test.go
@@ -19,6 +19,8 @@ package datastore
import (
"context"
"errors"
+ "net"
+ "strconv"
"testing"
"time"
@@ -35,6 +37,7 @@ import (
v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing"
)
@@ -83,21 +86,21 @@ func TestPool(t *testing.T) {
WithScheme(scheme).
Build()
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
- datastore := NewDatastore(context.Background(), pmf)
- _ = datastore.PoolSet(context.Background(), fakeClient, tt.inferencePool)
- gotPool, gotErr := datastore.PoolGet()
+ ds := NewDatastore(context.Background(), pmf, 0)
+ _ = ds.PoolSet(context.Background(), fakeClient, tt.inferencePool)
+ gotPool, gotErr := ds.PoolGet()
if diff := cmp.Diff(tt.wantErr, gotErr, cmpopts.EquateErrors()); diff != "" {
t.Errorf("Unexpected error diff (+got/-want): %s", diff)
}
if diff := cmp.Diff(tt.wantPool, gotPool); diff != "" {
t.Errorf("Unexpected pool diff (+got/-want): %s", diff)
}
- gotSynced := datastore.PoolHasSynced()
+ gotSynced := ds.PoolHasSynced()
if diff := cmp.Diff(tt.wantSynced, gotSynced); diff != "" {
t.Errorf("Unexpected synced diff (+got/-want): %s", diff)
}
if tt.labels != nil {
- gotLabelsMatch := datastore.PoolLabelsMatch(tt.labels)
+ gotLabelsMatch := ds.PoolLabelsMatch(tt.labels)
if diff := cmp.Diff(tt.wantLabelsMatch, gotLabelsMatch); diff != "" {
t.Errorf("Unexpected labels match diff (+got/-want): %s", diff)
}
@@ -190,7 +193,7 @@ func TestObjective(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
- ds := NewDatastore(t.Context(), pmf)
+ ds := NewDatastore(t.Context(), pmf, 0)
for _, m := range test.existingModels {
ds.ObjectiveSet(m)
}
@@ -241,13 +244,22 @@ var (
WaitingModels: map[string]int{},
}
- pod1NamespacedName = types.NamespacedName{Name: pod1.Name, Namespace: pod1.Namespace}
- pod2NamespacedName = types.NamespacedName{Name: pod2.Name, Namespace: pod2.Namespace}
+ pod1NamespacedName = types.NamespacedName{Name: pod1.Name + "-rank-0", Namespace: pod1.Namespace}
+ pod2NamespacedName = types.NamespacedName{Name: pod2.Name + "-rank-0", Namespace: pod2.Namespace}
inferencePool = &v1.InferencePool{
Spec: v1.InferencePoolSpec{
TargetPorts: []v1.Port{{Number: v1.PortNumber(int32(8000))}},
},
}
+ inferencePoolMultiTarget = &v1.InferencePool{
+ Spec: v1.InferencePoolSpec{
+ TargetPorts: []v1.Port{{Number: v1.PortNumber(int32(8000))}, {Number: v1.PortNumber(int32(8001))}},
+ },
+ }
+
+ inferencePoolTargetPort = strconv.Itoa(int(inferencePool.Spec.TargetPorts[0].Number))
+ inferencePoolMultiTargetPort0 = strconv.Itoa(int(inferencePoolMultiTarget.Spec.TargetPorts[0].Number))
+ inferencePoolMultiTargetPort1 = strconv.Itoa(int(inferencePoolMultiTarget.Spec.TargetPorts[1].Number))
)
func TestMetrics(t *testing.T) {
@@ -315,7 +327,7 @@ func TestMetrics(t *testing.T) {
WithScheme(scheme).
Build()
pmf := backendmetrics.NewPodMetricsFactory(test.pmc, time.Millisecond)
- ds := NewDatastore(ctx, pmf)
+ ds := NewDatastore(ctx, pmf, 0)
_ = ds.PoolSet(ctx, fakeClient, inferencePool)
for _, pod := range test.storePods {
ds.PodUpdateOrAddIfNotExist(pod)
@@ -340,14 +352,6 @@ func TestMetrics(t *testing.T) {
}
func TestPods(t *testing.T) {
- updatedPod := &corev1.Pod{
- ObjectMeta: metav1.ObjectMeta{
- Name: "pod1",
- },
- Spec: corev1.PodSpec{
- NodeName: "node-1",
- },
- }
tests := []struct {
name string
op func(ctx context.Context, ds Datastore)
@@ -371,60 +375,226 @@ func TestPods(t *testing.T) {
},
},
{
- name: "Update existing pod, new field, should update",
- existingPods: []*corev1.Pod{pod1},
- wantPods: []*corev1.Pod{updatedPod},
+ name: "Delete the pod",
+ existingPods: []*corev1.Pod{pod1, pod2},
+ wantPods: []*corev1.Pod{pod1},
op: func(ctx context.Context, ds Datastore) {
- ds.PodUpdateOrAddIfNotExist(updatedPod)
+ ds.PodDelete(pod2.Name)
},
},
{
- name: "Update existing pod, no new fields, should not update",
+ name: "Delete the pod that doesn't exist",
existingPods: []*corev1.Pod{pod1},
wantPods: []*corev1.Pod{pod1},
op: func(ctx context.Context, ds Datastore) {
- incoming := &corev1.Pod{
- ObjectMeta: metav1.ObjectMeta{
- Name: "pod1",
- Namespace: "default",
+ ds.PodDelete(pod2.Name)
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ctx := context.Background()
+ pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
+ ds := NewDatastore(t.Context(), pmf, 0)
+ fakeClient := fake.NewFakeClient()
+ if err := ds.PoolSet(ctx, fakeClient, inferencePool); err != nil {
+ t.Error(err)
+ }
+ for _, pod := range test.existingPods {
+ ds.PodUpdateOrAddIfNotExist(pod)
+ }
+
+ test.op(ctx, ds)
+ var gotPods []*corev1.Pod
+ for _, pm := range ds.PodList(backendmetrics.AllPodsPredicate) {
+ pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: pm.GetPod().PodName, Namespace: pm.GetPod().NamespacedName.Namespace}, Status: corev1.PodStatus{PodIP: pm.GetPod().GetIPAddress()}}
+ gotPods = append(gotPods, pod)
+ }
+ if !cmp.Equal(gotPods, test.wantPods, cmpopts.SortSlices(func(a, b *corev1.Pod) bool { return a.Name < b.Name })) {
+ t.Errorf("got (%v) != want (%v);", gotPods, test.wantPods)
+ }
+ })
+ }
+}
+
+func TestPodInfo(t *testing.T) {
+ tests := []struct {
+ name string
+ op func(ctx context.Context, ds Datastore)
+ pool *v1.InferencePool
+ existingPods []*corev1.Pod
+ wantPodInfos []*datalayer.PodInfo
+ }{
+ {
+ name: "Add new pod, no existing pods, should add",
+ existingPods: []*corev1.Pod{},
+ wantPodInfos: []*datalayer.PodInfo{
+ {
+ NamespacedName: types.NamespacedName{
+ Name: pod1.Name + "-rank-0",
+ Namespace: pod1.Namespace,
},
- }
- ds.PodUpdateOrAddIfNotExist(incoming)
+
+ PodName: pod1.Name,
+ Address: pod1.Status.PodIP,
+ Port: inferencePoolTargetPort,
+ MetricsHost: net.JoinHostPort(pod1.Status.PodIP, inferencePoolTargetPort),
+ Labels: map[string]string{},
+ },
+ },
+ op: func(ctx context.Context, ds Datastore) {
+ ds.PodUpdateOrAddIfNotExist(pod1)
},
+ pool: inferencePool,
},
{
- name: "Delete the pod",
- wantPods: []*corev1.Pod{pod1},
+ name: "Add new pod, no existing pods, should add, multiple target ports",
+ existingPods: []*corev1.Pod{},
+ wantPodInfos: []*datalayer.PodInfo{
+ {
+ NamespacedName: types.NamespacedName{
+ Name: pod1.Name + "-rank-0",
+ Namespace: pod1.Namespace,
+ },
+
+ PodName: pod1.Name,
+ Address: pod1.Status.PodIP,
+ Port: inferencePoolMultiTargetPort0,
+ MetricsHost: net.JoinHostPort(pod1.Status.PodIP, inferencePoolMultiTargetPort0),
+ Labels: map[string]string{},
+ },
+ {
+ NamespacedName: types.NamespacedName{
+ Name: pod1.Name + "-rank-1",
+ Namespace: pod1.Namespace,
+ },
+
+ PodName: pod1.Name,
+ Address: pod1.Status.PodIP,
+ Port: inferencePoolMultiTargetPort1,
+ MetricsHost: net.JoinHostPort(pod1.Status.PodIP, inferencePoolMultiTargetPort1),
+ Labels: map[string]string{},
+ },
+ },
op: func(ctx context.Context, ds Datastore) {
- ds.PodDelete(pod2NamespacedName)
+ ds.PodUpdateOrAddIfNotExist(pod1)
},
+ pool: inferencePoolMultiTarget,
},
{
- name: "Delete the pod that doesn't exist",
+ name: "Add new pod, with existing pods, should add, multiple target ports",
existingPods: []*corev1.Pod{pod1},
- wantPods: []*corev1.Pod{pod1},
+ wantPodInfos: []*datalayer.PodInfo{
+ {
+ NamespacedName: types.NamespacedName{
+ Name: pod1.Name + "-rank-0",
+ Namespace: pod1.Namespace,
+ },
+
+ PodName: pod1.Name,
+ Address: pod1.Status.PodIP,
+ Port: inferencePoolMultiTargetPort0,
+ MetricsHost: net.JoinHostPort(pod1.Status.PodIP, inferencePoolMultiTargetPort0),
+ Labels: map[string]string{},
+ },
+ {
+ NamespacedName: types.NamespacedName{
+ Name: pod1.Name + "-rank-1",
+ Namespace: pod1.Namespace,
+ },
+
+ PodName: pod1.Name,
+ Address: pod1.Status.PodIP,
+ Port: inferencePoolMultiTargetPort1,
+ MetricsHost: net.JoinHostPort(pod1.Status.PodIP, inferencePoolMultiTargetPort1),
+ Labels: map[string]string{},
+ },
+ {
+ NamespacedName: types.NamespacedName{
+ Name: pod2.Name + "-rank-0",
+ Namespace: pod2.Namespace,
+ },
+
+ PodName: pod2.Name,
+ Address: pod2.Status.PodIP,
+ Port: inferencePoolMultiTargetPort0,
+ MetricsHost: net.JoinHostPort(pod2.Status.PodIP, inferencePoolMultiTargetPort0),
+ Labels: map[string]string{},
+ },
+ {
+ NamespacedName: types.NamespacedName{
+ Name: pod2.Name + "-rank-1",
+ Namespace: pod2.Namespace,
+ },
+
+ PodName: pod2.Name,
+ Address: pod2.Status.PodIP,
+ Port: inferencePoolMultiTargetPort1,
+ MetricsHost: net.JoinHostPort(pod2.Status.PodIP, inferencePoolMultiTargetPort1),
+ Labels: map[string]string{},
+ },
+ },
+ op: func(ctx context.Context, ds Datastore) {
+ ds.PodUpdateOrAddIfNotExist(pod2)
+ },
+ pool: inferencePoolMultiTarget,
+ },
+ {
+ name: "Delete the pod, multiple target ports",
+ existingPods: []*corev1.Pod{pod1, pod2},
+ wantPodInfos: []*datalayer.PodInfo{
+ {
+ NamespacedName: types.NamespacedName{
+ Name: pod1.Name + "-rank-0",
+ Namespace: pod1.Namespace,
+ },
+
+ PodName: pod1.Name,
+ Address: pod1.Status.PodIP,
+ Port: inferencePoolMultiTargetPort0,
+ MetricsHost: net.JoinHostPort(pod1.Status.PodIP, inferencePoolMultiTargetPort0),
+ Labels: map[string]string{},
+ },
+ {
+ NamespacedName: types.NamespacedName{
+ Name: pod1.Name + "-rank-1",
+ Namespace: pod1.Namespace,
+ },
+
+ PodName: pod1.Name,
+ Address: pod1.Status.PodIP,
+ Port: inferencePoolMultiTargetPort1,
+ MetricsHost: net.JoinHostPort(pod1.Status.PodIP, inferencePoolMultiTargetPort1),
+ Labels: map[string]string{},
+ },
+ },
op: func(ctx context.Context, ds Datastore) {
- ds.PodDelete(pod2NamespacedName)
+ ds.PodDelete(pod2.Name)
},
+ pool: inferencePoolMultiTarget,
},
}
+
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := context.Background()
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
- ds := NewDatastore(t.Context(), pmf)
+ ds := NewDatastore(t.Context(), pmf, 0)
+ fakeClient := fake.NewFakeClient()
+ if err := ds.PoolSet(ctx, fakeClient, test.pool); err != nil {
+ t.Error(err)
+ }
for _, pod := range test.existingPods {
ds.PodUpdateOrAddIfNotExist(pod)
}
test.op(ctx, ds)
- var gotPods []*corev1.Pod
+ var gotPodInfos []*datalayer.PodInfo
for _, pm := range ds.PodList(backendmetrics.AllPodsPredicate) {
- pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: pm.GetPod().NamespacedName.Name, Namespace: pm.GetPod().NamespacedName.Namespace}, Status: corev1.PodStatus{PodIP: pm.GetPod().Address}}
- gotPods = append(gotPods, pod)
+ gotPodInfos = append(gotPodInfos, pm.GetPod())
}
- if !cmp.Equal(gotPods, test.wantPods, cmpopts.SortSlices(func(a, b *corev1.Pod) bool { return a.Name < b.Name })) {
- t.Logf("got (%v) != want (%v);", gotPods, test.wantPods)
+ if diff := cmp.Diff(test.wantPodInfos, gotPodInfos, cmpopts.SortSlices(func(a, b *datalayer.PodInfo) bool { return a.NamespacedName.Name < b.NamespacedName.Name })); diff != "" {
+ t.Errorf("ConvertTo() mismatch (-want +got):\n%s", diff)
}
})
}
diff --git a/pkg/epp/flowcontrol/config.go b/pkg/epp/flowcontrol/config.go
new file mode 100644
index 000000000..edc23abad
--- /dev/null
+++ b/pkg/epp/flowcontrol/config.go
@@ -0,0 +1,50 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package flowcontrol
+
+import (
+ "fmt"
+
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/controller"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/registry"
+)
+
+// Config is the top-level configuration for the entire flow control module.
+// It embeds the configurations for the controller and the registry, providing a single point of entry for validation
+// and initialization.
+type Config struct {
+ Controller controller.Config
+ Registry registry.Config
+}
+
+// ValidateAndApplyDefaults checks the configuration for validity and populates any empty fields with system defaults.
+// It delegates validation to the underlying controller and registry configurations.
+// It returns a new, validated `Config` object and does not mutate the receiver.
+func (c *Config) ValidateAndApplyDefaults() (*Config, error) {
+ validatedControllerCfg, err := c.Controller.ValidateAndApplyDefaults()
+ if err != nil {
+ return nil, fmt.Errorf("controller config validation failed: %w", err)
+ }
+ validatedRegistryCfg, err := c.Registry.ValidateAndApplyDefaults()
+ if err != nil {
+ return nil, fmt.Errorf("registry config validation failed: %w", err)
+ }
+ return &Config{
+ Controller: *validatedControllerCfg,
+ Registry: *validatedRegistryCfg,
+ }, nil
+}
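A hedged usage sketch of the top-level config: the field names come from this diff and the test below, the priority band values are arbitrary, and error handling is minimal. Zero-valued controller fields are filled in with defaults and the receiver is left untouched.

package main

import (
	"fmt"
	"time"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/controller"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/registry"
)

func main() {
	cfg := flowcontrol.Config{
		Controller: controller.Config{DefaultRequestTTL: 30 * time.Second},
		Registry: registry.Config{
			PriorityBands: []registry.PriorityBandConfig{
				{Priority: 1, PriorityName: "Critical"},
			},
		},
	}
	validated, err := cfg.ValidateAndApplyDefaults() // cfg itself is not mutated
	if err != nil {
		panic(err)
	}
	fmt.Println(validated.Controller.ExpiryCleanupInterval) // defaulted to 1s
}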
diff --git a/pkg/epp/flowcontrol/config_test.go b/pkg/epp/flowcontrol/config_test.go
new file mode 100644
index 000000000..713abee77
--- /dev/null
+++ b/pkg/epp/flowcontrol/config_test.go
@@ -0,0 +1,91 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package flowcontrol
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/controller"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/registry"
+)
+
+func TestConfig_ValidateAndApplyDefaults(t *testing.T) {
+ t.Parallel()
+
+ // A minimal valid registry config, which is required for the success case.
+ validRegistryConfig := registry.Config{
+ PriorityBands: []registry.PriorityBandConfig{
+ {Priority: 1, PriorityName: "TestBand"},
+ },
+ }
+
+ testCases := []struct {
+ name string
+ input Config
+ expectErr bool
+ expectedErrIs error
+ }{
+ {
+ name: "ShouldSucceed_WhenSubConfigsAreValid",
+ input: Config{
+ Controller: controller.Config{},
+ Registry: validRegistryConfig,
+ },
+ expectErr: false,
+ },
+ {
+ name: "ShouldFail_WhenControllerConfigIsInvalid",
+ input: Config{
+ Controller: controller.Config{
+ DefaultRequestTTL: -1 * time.Second,
+ },
+ Registry: validRegistryConfig,
+ },
+ expectErr: true,
+ },
+ {
+ name: "ShouldFail_WhenRegistryConfigIsInvalid",
+ input: Config{
+ Controller: controller.Config{},
+ Registry: registry.Config{
+ PriorityBands: []registry.PriorityBandConfig{},
+ },
+ },
+ expectErr: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ originalInput := tc.input
+ validatedCfg, err := tc.input.ValidateAndApplyDefaults()
+
+ if tc.expectErr {
+ require.Error(t, err, "expected an error but got nil")
+ } else {
+ require.NoError(t, err, "expected no error but got: %v", err)
+ require.NotNil(t, validatedCfg, "validatedCfg should not be nil on success")
+ }
+
+ assert.Equal(t, originalInput, tc.input, "input config should not be mutated")
+ })
+ }
+}
diff --git a/pkg/epp/flowcontrol/contracts/mocks/mocks.go b/pkg/epp/flowcontrol/contracts/mocks/mocks.go
index c5c8d2e3b..10f093b11 100644
--- a/pkg/epp/flowcontrol/contracts/mocks/mocks.go
+++ b/pkg/epp/flowcontrol/contracts/mocks/mocks.go
@@ -34,6 +34,7 @@ import (
"fmt"
"sync"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types"
@@ -48,9 +49,9 @@ type MockRegistryShard struct {
IsActiveFunc func() bool
ManagedQueueFunc func(key types.FlowKey) (contracts.ManagedQueue, error)
IntraFlowDispatchPolicyFunc func(key types.FlowKey) (framework.IntraFlowDispatchPolicy, error)
- InterFlowDispatchPolicyFunc func(priority uint) (framework.InterFlowDispatchPolicy, error)
- PriorityBandAccessorFunc func(priority uint) (framework.PriorityBandAccessor, error)
- AllOrderedPriorityLevelsFunc func() []uint
+ InterFlowDispatchPolicyFunc func(priority int) (framework.InterFlowDispatchPolicy, error)
+ PriorityBandAccessorFunc func(priority int) (framework.PriorityBandAccessor, error)
+ AllOrderedPriorityLevelsFunc func() []int
StatsFunc func() contracts.ShardStats
}
@@ -82,21 +83,21 @@ func (m *MockRegistryShard) IntraFlowDispatchPolicy(key types.FlowKey) (framewor
return nil, nil
}
-func (m *MockRegistryShard) InterFlowDispatchPolicy(priority uint) (framework.InterFlowDispatchPolicy, error) {
+func (m *MockRegistryShard) InterFlowDispatchPolicy(priority int) (framework.InterFlowDispatchPolicy, error) {
if m.InterFlowDispatchPolicyFunc != nil {
return m.InterFlowDispatchPolicyFunc(priority)
}
return nil, nil
}
-func (m *MockRegistryShard) PriorityBandAccessor(priority uint) (framework.PriorityBandAccessor, error) {
+func (m *MockRegistryShard) PriorityBandAccessor(priority int) (framework.PriorityBandAccessor, error) {
if m.PriorityBandAccessorFunc != nil {
return m.PriorityBandAccessorFunc(priority)
}
return nil, nil
}
-func (m *MockRegistryShard) AllOrderedPriorityLevels() []uint {
+func (m *MockRegistryShard) AllOrderedPriorityLevels() []int {
if m.AllOrderedPriorityLevelsFunc != nil {
return m.AllOrderedPriorityLevelsFunc()
}
@@ -112,12 +113,12 @@ func (m *MockRegistryShard) Stats() contracts.ShardStats {
// MockSaturationDetector is a simple "stub-style" mock for testing.
type MockSaturationDetector struct {
- IsSaturatedFunc func(ctx context.Context) bool
+ IsSaturatedFunc func(ctx context.Context, candidatePods []metrics.PodMetrics) bool
}
-func (m *MockSaturationDetector) IsSaturated(ctx context.Context) bool {
+func (m *MockSaturationDetector) IsSaturated(ctx context.Context, candidatePods []metrics.PodMetrics) bool {
if m.IsSaturatedFunc != nil {
- return m.IsSaturatedFunc(ctx)
+ return m.IsSaturatedFunc(ctx, candidatePods)
}
return false
}
diff --git a/pkg/epp/flowcontrol/contracts/registry.go b/pkg/epp/flowcontrol/contracts/registry.go
index de1b89ae6..fe0b790b9 100644
--- a/pkg/epp/flowcontrol/contracts/registry.go
+++ b/pkg/epp/flowcontrol/contracts/registry.go
@@ -22,8 +22,8 @@ import (
)
// FlowRegistry is the complete interface for the global flow control plane.
-// It composes the client-facing data path interface and the administrative interface. A concrete implementation of this
-// interface is the single source of truth for all flow control state.
+// It composes all role-based interfaces. A concrete implementation of this interface is the single source of truth for
+// all flow control state.
//
// # Conformance: Implementations MUST be goroutine-safe.
//
@@ -48,22 +48,21 @@ import (
// 2. Capacity Partitioning: Global and per-band capacity limits must be uniformly partitioned across all Active
// shards.
type FlowRegistry interface {
- FlowRegistryClient
- FlowRegistryAdmin
+ FlowRegistryObserver
+ FlowRegistryDataPlane
}
-// FlowRegistryAdmin defines the administrative interface for the global control plane.
-type FlowRegistryAdmin interface {
- // Stats returns globally aggregated statistics for the entire `FlowRegistry`.
+// FlowRegistryObserver defines the read-only, observation interface for the registry.
+type FlowRegistryObserver interface {
+ // Stats returns a near-consistent snapshot of globally aggregated statistics for the entire `FlowRegistry`.
Stats() AggregateStats
- // ShardStats returns a slice of statistics, one for each internal shard.
+ // ShardStats returns a near-consistent slice of statistics snapshots, one for each `RegistryShard`.
ShardStats() []ShardStats
}
-// FlowRegistryClient defines the primary, client-facing interface for the registry.
-// This is the interface that the `controller.FlowController`'s data path depends upon.
-type FlowRegistryClient interface {
+// FlowRegistryDataPlane defines the high-throughput, request-path interface for the registry.
+type FlowRegistryDataPlane interface {
// WithConnection manages a scoped, leased session for a given flow.
// It is the primary and sole entry point for interacting with the data path.
//
@@ -90,9 +89,8 @@ type FlowRegistryClient interface {
// Its purpose is to ensure that any interaction with the flow's state (e.g., accessing its shards and queues) occurs
// safely while the flow is guaranteed to be protected from garbage collection.
type ActiveFlowConnection interface {
- // Shards returns a stable snapshot of accessors for all internal state shards (both Active and Draining).
- // Consumers MUST check `RegistryShard.IsActive()` before routing new work to a shard from this slice.
- Shards() []RegistryShard
+ // ActiveShards returns a stable snapshot of accessors for all Active internal state shards.
+ ActiveShards() []RegistryShard
}
// RegistryShard defines the interface for a single slice (shard) of the `FlowRegistry`'s state.
@@ -124,22 +122,22 @@ type RegistryShard interface {
// InterFlowDispatchPolicy retrieves a priority band's configured `framework.InterFlowDispatchPolicy` for this shard.
// The registry guarantees that a non-nil default policy is returned if none is configured for the band.
// Returns an error wrapping `ErrPriorityBandNotFound` if the priority level is not configured.
- InterFlowDispatchPolicy(priority uint) (framework.InterFlowDispatchPolicy, error)
+ InterFlowDispatchPolicy(priority int) (framework.InterFlowDispatchPolicy, error)
// PriorityBandAccessor retrieves a read-only accessor for a given priority level, providing a view of the band's
// state as seen by this specific shard. This is the primary entry point for inter-flow dispatch policies that need to
// inspect and compare multiple flow queues within the same priority band.
// Returns an error wrapping `ErrPriorityBandNotFound` if the priority level is not configured.
- PriorityBandAccessor(priority uint) (framework.PriorityBandAccessor, error)
+ PriorityBandAccessor(priority int) (framework.PriorityBandAccessor, error)
- // AllOrderedPriorityLevels returns all configured priority levels that this shard is aware of, sorted in ascending
- // numerical order. This order corresponds to highest priority (lowest numeric value) to lowest priority (highest
+ // AllOrderedPriorityLevels returns all configured priority levels that this shard is aware of, sorted in descending
+ // numerical order. This order corresponds to highest priority (highest numeric value) to lowest priority (lowest
// numeric value).
// The returned slice provides a definitive, ordered list of priority levels for iteration, for example, by a
// `controller.FlowController` worker's dispatch loop.
- AllOrderedPriorityLevels() []uint
+ AllOrderedPriorityLevels() []int
- // Stats returns a snapshot of the statistics for this specific shard.
+ // Stats returns a near-consistent snapshot of the shard's state.
Stats() ShardStats
}
@@ -162,6 +160,7 @@ type ManagedQueue interface {
}
// AggregateStats holds globally aggregated statistics for the entire `FlowRegistry`.
+// It is a read-only data object representing a near-consistent snapshot of the registry's state.
type AggregateStats struct {
// TotalCapacityBytes is the globally configured maximum total byte size limit across all priority bands and shards.
TotalCapacityBytes uint64
@@ -170,11 +169,18 @@ type AggregateStats struct {
// TotalLen is the total number of items currently queued across the entire system.
TotalLen uint64
// PerPriorityBandStats maps each configured priority level to its globally aggregated statistics.
- PerPriorityBandStats map[uint]PriorityBandStats
+ PerPriorityBandStats map[int]PriorityBandStats
}
-// ShardStats holds statistics for a single internal shard within the `FlowRegistry`.
+// ShardStats holds statistics and identifying information for a `RegistryShard` within the `FlowRegistry`.
+// It is a read-only data object representing a near-consistent snapshot of the shard's state.
type ShardStats struct {
+ // ID is the unique, stable identifier for this shard.
+ ID string
+ // IsActive indicates if the shard was accepting new work at the time this stats snapshot was generated.
+ // A value of `false` means the shard is in the process of being gracefully drained.
+ // Due to the concurrent nature of the system, this state could change immediately after the snapshot is taken.
+ IsActive bool
// TotalCapacityBytes is the optional, maximum total byte size limit aggregated across all priority bands within this
// shard. Its value represents the globally configured limit for the `FlowRegistry` partitioned for this shard.
// The `controller.FlowController` enforces this limit in addition to any per-band capacity limits.
@@ -188,13 +194,14 @@ type ShardStats struct {
// The capacity values within represent this shard's partition of the global band capacity.
// The key is the numerical priority level.
// All configured priority levels are guaranteed to be represented.
- PerPriorityBandStats map[uint]PriorityBandStats
+ PerPriorityBandStats map[int]PriorityBandStats
}
// PriorityBandStats holds aggregated statistics for a single priority band.
+// It is a read-only data object representing a near-consistent snapshot of the priority band's state.
type PriorityBandStats struct {
// Priority is the numerical priority level this struct describes.
- Priority uint
+ Priority int
// PriorityName is a human-readable name for the priority band (e.g., "Critical", "Sheddable").
// The registry configuration requires this field, so it is guaranteed to be non-empty.
PriorityName string
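Illustrative only: with priorities now signed ints ordered from highest numeric value (highest priority) downward, a worker's band walk might look like the sketch below. `shard` is assumed to satisfy contracts.RegistryShard; error handling is trimmed.

package example

import (
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts"
)

// walkBands iterates a shard's priority bands in dispatch order. Per the
// contract, AllOrderedPriorityLevels is already sorted from highest priority
// (largest int) to lowest, and PriorityBandAccessor errors only when the
// level is not configured.
func walkBands(shard contracts.RegistryShard) error {
	for _, priority := range shard.AllOrderedPriorityLevels() {
		band, err := shard.PriorityBandAccessor(priority)
		if err != nil {
			return err
		}
		_ = band // inspect or compare the band's flow queues here
	}
	return nil
}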
diff --git a/pkg/epp/flowcontrol/contracts/saturationdetector.go b/pkg/epp/flowcontrol/contracts/saturationdetector.go
index 91d2406c5..15037d50a 100644
--- a/pkg/epp/flowcontrol/contracts/saturationdetector.go
+++ b/pkg/epp/flowcontrol/contracts/saturationdetector.go
@@ -16,7 +16,11 @@ limitations under the License.
package contracts
-import "context"
+import (
+ "context"
+
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+)
// SaturationDetector defines the contract for a component that provides real-time load signals to the
// `controller.FlowController`.
@@ -32,8 +36,8 @@ import "context"
//
// Implementations MUST be goroutine-safe.
type SaturationDetector interface {
- // IsSaturated returns true if the system's backend resources are considered saturated.
+ // IsSaturated returns true if the system's backend resources are considered saturated for a set of candidate pods.
// `controller.FlowController`'s dispatch workers call this method to decide whether to pause or throttle dispatch
// operations to prevent overwhelming the backends.
- IsSaturated(ctx context.Context) bool
+ IsSaturated(ctx context.Context, candidatePods []metrics.PodMetrics) bool
}
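A minimal sketch of satisfying the updated contract (not the project's real detector): the candidate-pod slice is now part of the signature, so even a placeholder policy can key off the pods a request could be routed to.

package example

import (
	"context"

	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts"
)

// noopDetector is a placeholder SaturationDetector: saturation is judged
// against the candidate pods for a specific request rather than globally.
type noopDetector struct{}

func (noopDetector) IsSaturated(_ context.Context, candidatePods []metrics.PodMetrics) bool {
	// Placeholder policy: an empty candidate set is treated as saturated.
	return len(candidatePods) == 0
}

// Compile-time check that the placeholder satisfies the contract.
var _ contracts.SaturationDetector = noopDetector{}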
diff --git a/pkg/epp/flowcontrol/controller/config.go b/pkg/epp/flowcontrol/controller/config.go
new file mode 100644
index 000000000..e542c4d6b
--- /dev/null
+++ b/pkg/epp/flowcontrol/controller/config.go
@@ -0,0 +1,102 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package controller
+
+import (
+ "fmt"
+ "time"
+)
+
+const (
+ // defaultExpiryCleanupInterval is the default frequency for scanning for expired items.
+ defaultExpiryCleanupInterval = 1 * time.Second
+ // defaultProcessorReconciliationInterval is the default frequency for the supervisor loop.
+ defaultProcessorReconciliationInterval = 5 * time.Second
+ // defaultEnqueueChannelBufferSize is the default size of a worker's incoming request buffer.
+ defaultEnqueueChannelBufferSize = 100
+)
+
+// Config holds the configuration for the `FlowController`.
+type Config struct {
+ // DefaultRequestTTL is the default Time-To-Live applied to requests that do not
+ // specify their own TTL hint.
+ // Optional: If zero, no TTL is applied by default and we rely solely on request context cancellation.
+ DefaultRequestTTL time.Duration
+
+ // ExpiryCleanupInterval is the interval at which each shard processor scans its queues for expired items.
+ // Optional: Defaults to `defaultExpiryCleanupInterval` (1 second).
+ ExpiryCleanupInterval time.Duration
+
+ // ProcessorReconciliationInterval is the frequency at which the `FlowController`'s supervisor loop garbage collects
+ // stale workers.
+ // Optional: Defaults to `defaultProcessorReconciliationInterval` (5 seconds).
+ ProcessorReconciliationInterval time.Duration
+
+ // EnqueueChannelBufferSize is the size of the buffered channel that accepts incoming requests for each shard
+ // processor. This buffer acts as a shock absorber, decoupling the high-frequency distributor from the processor's
+ // serial execution loop and allowing the system to handle short bursts of traffic without blocking.
+ // Optional: Defaults to `defaultEnqueueChannelBufferSize` (100).
+ EnqueueChannelBufferSize int
+}
+
+// ValidateAndApplyDefaults checks the global configuration for validity and then creates a new `Config` object,
+// populating any empty fields with system defaults.
+// It does not mutate the receiver.
+func (c *Config) ValidateAndApplyDefaults() (*Config, error) {
+ cfg := c.deepCopy()
+
+ // --- Validation ---
+ if cfg.DefaultRequestTTL < 0 {
+ return nil, fmt.Errorf("DefaultRequestTTL cannot be negative, but got %v", cfg.DefaultRequestTTL)
+ }
+ if cfg.ExpiryCleanupInterval < 0 {
+ return nil, fmt.Errorf("ExpiryCleanupInterval cannot be negative, but got %v", cfg.ExpiryCleanupInterval)
+ }
+ if cfg.ProcessorReconciliationInterval < 0 {
+ return nil, fmt.Errorf("ProcessorReconciliationInterval cannot be negative, but got %v",
+ cfg.ProcessorReconciliationInterval)
+ }
+ if cfg.EnqueueChannelBufferSize < 0 {
+ return nil, fmt.Errorf("EnqueueChannelBufferSize cannot be negative, but got %d", cfg.EnqueueChannelBufferSize)
+ }
+
+ // --- Defaulting ---
+ if cfg.ExpiryCleanupInterval == 0 {
+ cfg.ExpiryCleanupInterval = defaultExpiryCleanupInterval
+ }
+ if cfg.ProcessorReconciliationInterval == 0 {
+ cfg.ProcessorReconciliationInterval = defaultProcessorReconciliationInterval
+ }
+ if cfg.EnqueueChannelBufferSize == 0 {
+ cfg.EnqueueChannelBufferSize = defaultEnqueueChannelBufferSize
+ }
+ return cfg, nil
+}
+
+// deepCopy creates a deep copy of the `Config` object.
+func (c *Config) deepCopy() *Config {
+ if c == nil {
+ return nil
+ }
+ newCfg := &Config{
+ DefaultRequestTTL: c.DefaultRequestTTL,
+ ExpiryCleanupInterval: c.ExpiryCleanupInterval,
+ ProcessorReconciliationInterval: c.ProcessorReconciliationInterval,
+ EnqueueChannelBufferSize: c.EnqueueChannelBufferSize,
+ }
+ return newCfg
+}
diff --git a/pkg/epp/flowcontrol/controller/config_test.go b/pkg/epp/flowcontrol/controller/config_test.go
new file mode 100644
index 000000000..710df9fa7
--- /dev/null
+++ b/pkg/epp/flowcontrol/controller/config_test.go
@@ -0,0 +1,135 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package controller
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestConfig_ValidateAndApplyDefaults(t *testing.T) {
+ t.Parallel()
+
+ testCases := []struct {
+ name string
+ input Config
+ expectErr bool
+ expectedCfg Config
+ shouldDefault bool
+ }{
+ {
+ name: "ValidConfig_NoChanges",
+ input: Config{
+ DefaultRequestTTL: 10 * time.Second,
+ ExpiryCleanupInterval: 2 * time.Second,
+ ProcessorReconciliationInterval: 10 * time.Second,
+ EnqueueChannelBufferSize: 200,
+ },
+ expectErr: false,
+ expectedCfg: Config{
+ DefaultRequestTTL: 10 * time.Second,
+ ExpiryCleanupInterval: 2 * time.Second,
+ ProcessorReconciliationInterval: 10 * time.Second,
+ EnqueueChannelBufferSize: 200,
+ },
+ },
+ {
+ name: "EmptyConfig_ShouldApplyDefaults",
+ input: Config{},
+ expectErr: false,
+ expectedCfg: Config{
+ DefaultRequestTTL: 0,
+ ExpiryCleanupInterval: defaultExpiryCleanupInterval,
+ ProcessorReconciliationInterval: defaultProcessorReconciliationInterval,
+ EnqueueChannelBufferSize: defaultEnqueueChannelBufferSize,
+ },
+ shouldDefault: true,
+ },
+ {
+ name: "NegativeDefaultRequestTTL_Invalid",
+ input: Config{DefaultRequestTTL: -1},
+ expectErr: true,
+ },
+ {
+ name: "NegativeExpiryCleanupInterval_Invalid",
+ input: Config{ExpiryCleanupInterval: -1},
+ expectErr: true,
+ },
+ {
+ name: "NegativeProcessorReconciliationInterval_Invalid",
+ input: Config{ProcessorReconciliationInterval: -1},
+ expectErr: true,
+ },
+ {
+ name: "NegativeEnqueueChannelBufferSize_Invalid",
+ input: Config{EnqueueChannelBufferSize: -1},
+ expectErr: true,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ originalInput := tc.input.deepCopy()
+ validatedCfg, err := tc.input.ValidateAndApplyDefaults()
+
+ if tc.expectErr {
+ require.Error(t, err, "expected an error but got nil")
+ assert.Nil(t, validatedCfg, "validatedCfg should be nil on error")
+ } else {
+ require.NoError(t, err, "expected no error but got: %v", err)
+ require.NotNil(t, validatedCfg, "validatedCfg should not be nil on success")
+ assert.Equal(t, tc.expectedCfg, *validatedCfg, "validatedCfg should match expected config")
+
+ // Ensure the original config is not mutated.
+ assert.Equal(t, *originalInput, tc.input, "input config should not be mutated")
+ }
+ })
+ }
+}
+
+func TestConfig_DeepCopy(t *testing.T) {
+ t.Parallel()
+
+ t.Run("ShouldReturnNil_ForNilReceiver", func(t *testing.T) {
+ t.Parallel()
+ var nilConfig *Config
+ assert.Nil(t, nilConfig.deepCopy(), "Deep copy of a nil config should be nil")
+ })
+
+ t.Run("ShouldCreateIdenticalButSeparateObject", func(t *testing.T) {
+ t.Parallel()
+ original := &Config{
+ DefaultRequestTTL: 1 * time.Second,
+ ExpiryCleanupInterval: 2 * time.Second,
+ ProcessorReconciliationInterval: 3 * time.Second,
+ EnqueueChannelBufferSize: 4,
+ }
+ clone := original.deepCopy()
+
+ require.NotSame(t, original, clone, "Clone should be a new object in memory")
+ assert.Equal(t, *original, *clone, "Cloned object should have identical values")
+
+ // Modify the clone and ensure the original is unchanged.
+ clone.DefaultRequestTTL = 99 * time.Second
+ assert.NotEqual(t, original.DefaultRequestTTL, clone.DefaultRequestTTL,
+ "Original should not be mutated after clone is changed")
+ })
+}
diff --git a/pkg/epp/flowcontrol/controller/controller.go b/pkg/epp/flowcontrol/controller/controller.go
new file mode 100644
index 000000000..a11ef26a5
--- /dev/null
+++ b/pkg/epp/flowcontrol/controller/controller.go
@@ -0,0 +1,516 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package controller contains the implementation of the FlowController engine.
+//
+// The FlowController is the central processing engine of the Flow Control layer. It is a sharded, high-throughput
+// component responsible for managing the lifecycle of all incoming requests. It achieves this by acting as a stateless
+// supervisor that orchestrates a pool of stateful workers (ShardProcessors), distributing incoming requests among them.
+package controller
+
+import (
+ "cmp"
+ "context"
+ "errors"
+ "fmt"
+ "slices"
+ "strconv"
+ "sync"
+ "time"
+
+ "github.com/go-logr/logr"
+ k8srand "k8s.io/apimachinery/pkg/util/rand"
+ "k8s.io/utils/clock"
+
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/controller/internal"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
+ logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+)
+
+// registryClient defines the minimal interface that the FlowController needs to interact with the FlowRegistry.
+type registryClient interface {
+ contracts.FlowRegistryObserver
+ contracts.FlowRegistryDataPlane
+}
+
+// shardProcessor is the minimal internal interface that the FlowController requires from its workers.
+type shardProcessor interface {
+ Run(ctx context.Context)
+ Submit(item *internal.FlowItem) error
+ SubmitOrBlock(ctx context.Context, item *internal.FlowItem) error
+}
+
+// shardProcessorFactory defines the signature for creating a shardProcessor.
+type shardProcessorFactory func(
+ ctx context.Context,
+ shard contracts.RegistryShard,
+ saturationDetector contracts.SaturationDetector,
+ clock clock.WithTicker,
+ cleanupSweepInterval time.Duration,
+ enqueueChannelBufferSize int,
+ logger logr.Logger,
+) shardProcessor
+
+var _ shardProcessor = &internal.ShardProcessor{}
+
+// managedWorker holds the state for a single supervised worker.
+type managedWorker struct {
+ processor shardProcessor
+ // cancel function for the worker-specific context. Used during shutdown and GC.
+ cancel context.CancelFunc
+}
+
+// FlowController is the central, high-throughput engine of the Flow Control layer.
+// It is designed as a stateless distributor that orchestrates a pool of stateful workers (ShardProcessor), following a
+// supervisor-worker pattern.
+//
+// The controller's run loop executes periodically, acting as a garbage collector that keeps the pool of running
+// workers synchronized with the dynamic shard topology of the FlowRegistry.
+//
+// Request Lifecycle Management:
+//
+// 1. Asynchronous Finalization (Controller-Owned): The Controller actively monitors the request Context
+// (TTL/Cancellation) in EnqueueAndWait. If the Context expires, the Controller immediately Finalizes the item and
+// unblocks the caller.
+// 2. Synchronous Finalization (Processor-Owned): The Processor handles Dispatch, Capacity Rejection, and Shutdown.
+// 3. Cleanup (Processor-Owned): The Processor periodically sweeps externally finalized items to reclaim capacity.
+type FlowController struct {
+ // --- Immutable dependencies (set at construction) ---
+
+ config Config
+ registry registryClient
+ saturationDetector contracts.SaturationDetector
+ clock clock.WithTicker
+ logger logr.Logger
+ shardProcessorFactory shardProcessorFactory
+
+ // --- Lifecycle state ---
+
+ // parentCtx is the root context for the controller's lifecycle, established when NewFlowController is called.
+ // It is the parent for all long-lived worker goroutines.
+ parentCtx context.Context
+
+ // --- Concurrent state ---
+
+ // workers is a highly concurrent map storing the managedWorker for each shard.
+ // It is the controller's source of truth for the worker pool.
+ workers sync.Map // key: shard ID (string); value: *managedWorker
+
+ // wg waits for all worker goroutines to terminate during shutdown.
+ wg sync.WaitGroup
+}
+
+// flowControllerOption is a function that applies a configuration change.
+// test-only
+type flowControllerOption func(*FlowController)
+
+// NewFlowController creates and starts a new FlowController instance.
+// The provided context governs the lifecycle of the controller and all its workers.
+func NewFlowController(
+ ctx context.Context,
+ config Config,
+ registry contracts.FlowRegistry,
+ sd contracts.SaturationDetector,
+ logger logr.Logger,
+ opts ...flowControllerOption,
+) (*FlowController, error) {
+ fc := &FlowController{
+ config: config,
+ registry: registry,
+ saturationDetector: sd,
+ clock: clock.RealClock{},
+ logger: logger.WithName("flow-controller"),
+ parentCtx: ctx,
+ }
+
+ fc.shardProcessorFactory = func(
+ ctx context.Context,
+ shard contracts.RegistryShard,
+ saturationDetector contracts.SaturationDetector,
+ clock clock.WithTicker,
+ cleanupSweepInterval time.Duration,
+ enqueueChannelBufferSize int,
+ logger logr.Logger,
+ ) shardProcessor {
+ return internal.NewShardProcessor(
+ ctx,
+ shard,
+ saturationDetector,
+ clock,
+ cleanupSweepInterval,
+ enqueueChannelBufferSize,
+ logger)
+ }
+
+ for _, opt := range opts {
+ opt(fc)
+ }
+
+ go fc.run(ctx)
+ return fc, nil
+}
+
+// run starts the FlowController's main reconciliation loop (supervisor loop).
+// This loop is responsible for garbage collecting workers whose shards no longer exist in the registry.
+// This method blocks until the provided context is cancelled and all worker goroutines have fully terminated.
+func (fc *FlowController) run(ctx context.Context) {
+ fc.logger.Info("Starting FlowController reconciliation loop.")
+ defer fc.logger.Info("FlowController reconciliation loop stopped.")
+
+ ticker := fc.clock.NewTicker(fc.config.ProcessorReconciliationInterval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ fc.shutdown()
+ return
+ case <-ticker.C():
+ fc.reconcileProcessors()
+ }
+ }
+}
+
+// EnqueueAndWait is the primary, synchronous entry point to the Flow Control system. It submits a request and blocks
+// until the request reaches a terminal outcome (dispatched, rejected, or evicted).
+//
+// # Design Rationale: The Synchronous Model
+//
+// This blocking model is deliberately chosen for its simplicity and robustness, especially in the context of Envoy
+// External Processing (ext_proc), which operates on a stream-based protocol.
+//
+// - ext_proc Alignment: A single goroutine typically manages the stream for a given HTTP request.
+// EnqueueAndWait fits this perfectly: the request-handling goroutine calls it, blocks, and upon return, has a
+// definitive outcome to act upon.
+// - Simplified State Management: The state of a "waiting" request is implicitly managed by the blocked goroutine's
+// stack and its Context. The system only needs to signal this specific goroutine to unblock it.
+// - Direct Backpressure: If queues are full, EnqueueAndWait returns an error immediately, providing direct
+// backpressure to the caller.
+func (fc *FlowController) EnqueueAndWait(
+ ctx context.Context,
+ req types.FlowControlRequest,
+) (types.QueueOutcome, error) {
+ flowKey := req.FlowKey()
+ fairnessID := flowKey.ID
+ priority := strconv.Itoa(flowKey.Priority)
+ metrics.IncFlowControlQueueSize(fairnessID, priority)
+ defer metrics.DecFlowControlQueueSize(fairnessID, priority)
+
+ // 1. Create the derived context that governs this request's lifecycle (Parent Cancellation + TTL).
+ reqCtx, cancel, enqueueTime := fc.createRequestContext(ctx, req)
+ defer cancel()
+
+ // 2. Enter the distribution loop to find a home for the request.
+ // This loop is responsible for retrying on ErrShardDraining.
+ for {
+ select { // Non-blocking check on controller lifecycle.
+ case <-fc.parentCtx.Done():
+ return types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, types.ErrFlowControllerNotRunning)
+ default:
+ }
+
+ // Attempt to distribute the request once.
+ item, err := fc.tryDistribution(reqCtx, req, enqueueTime)
+ if err != nil {
+ // Distribution failed terminally (e.g., no shards, context cancelled during blocking submit).
+ // The item has already been finalized by tryDistribution.
+ finalState := item.FinalState()
+ return finalState.Outcome, finalState.Err
+ }
+
+ // Distribution was successful; ownership of the item has been transferred to a processor.
+ // Now, we block here in awaitFinalization until the request is finalized by either the processor (e.g., dispatched,
+ // rejected) or the controller itself (e.g., caller's context cancelled/TTL expired).
+ outcome, err := fc.awaitFinalization(reqCtx, item)
+ if errors.Is(err, contracts.ErrShardDraining) {
+ // This is a benign race condition where the chosen shard started draining after acceptance.
+ fc.logger.V(logutil.DEBUG).Info("Selected shard is Draining, retrying request distribution",
+ "flowKey", req.FlowKey(), "requestID", req.ID())
+ // Introduce a small, randomized delay (1-10ms) to prevent tight spinning loops and thundering herds during retry
+ // scenarios (e.g., shard draining)
+ // TODO: Replace this with a more sophisticated backoff strategy when our data parallelism story matures.
+ // For now, this is more than sufficient.
+ jitterMs := k8srand.Intn(10) + 1
+ fc.clock.Sleep(time.Duration(jitterMs) * time.Millisecond)
+ continue
+ }
+
+ // The outcome is terminal (Dispatched, Evicted, or a non-retriable rejection).
+ return outcome, err
+ }
+}
+
+var errNoShards = errors.New("no viable active shards available")
+
+// tryDistribution handles a single attempt to select a shard and submit a request.
+// If this function returns an error, it guarantees that the provided `item` has been finalized.
+func (fc *FlowController) tryDistribution(
+ reqCtx context.Context,
+ req types.FlowControlRequest,
+ enqueueTime time.Time,
+) (*internal.FlowItem, error) {
+ // Calculate effective TTL for item initialization (reqCtx is the enforcement mechanism).
+ effectiveTTL := fc.config.DefaultRequestTTL
+ if deadline, ok := reqCtx.Deadline(); ok {
+ if ttl := deadline.Sub(enqueueTime); ttl > 0 {
+ effectiveTTL = ttl
+ }
+ }
+
+ // We must create a fresh FlowItem on each attempt as finalization is per-lifecycle.
+ item := internal.NewItem(req, effectiveTTL, enqueueTime)
+
+ candidates, err := fc.selectDistributionCandidates(item.OriginalRequest().FlowKey())
+ if err != nil {
+ outcome := types.QueueOutcomeRejectedOther
+ if errors.Is(err, errNoShards) {
+ outcome = types.QueueOutcomeRejectedCapacity
+ }
+ finalErr := fmt.Errorf("%w: request not accepted: %w", types.ErrRejected, err)
+ item.FinalizeWithOutcome(outcome, finalErr)
+ return item, finalErr
+ }
+
+ outcome, err := fc.distributeRequest(reqCtx, item, candidates)
+ if err == nil {
+ // Success: Ownership of the item has been transferred to the processor.
+ return item, nil
+ }
+
+ // For any distribution error, the controller retains ownership and must finalize the item.
+ var finalErr error
+ if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+ // We propagate the original context error here; EnqueueAndWait will rely on item.FinalState().Err.
+ finalErr = err
+ item.Finalize(context.Cause(reqCtx))
+ } else { // e.g., a non-retriable submission failure such as a capacity rejection.
+ finalErr = fmt.Errorf("%w: request not accepted: %w", types.ErrRejected, err)
+ item.FinalizeWithOutcome(outcome, finalErr)
+ }
+ return item, finalErr
+}
+
+// awaitFinalization blocks until an item is finalized, either by the processor (synchronously) or by the controller
+// itself due to context expiry (asynchronously).
+func (fc *FlowController) awaitFinalization(
+ reqCtx context.Context,
+ item *internal.FlowItem,
+) (types.QueueOutcome, error) {
+ select {
+ case <-reqCtx.Done():
+ // Asynchronous Finalization (Controller-initiated):
+ // The request Context expired (Cancellation/TTL) while the item was being processed.
+ cause := context.Cause(reqCtx)
+ item.Finalize(cause)
+
+ // The processor will eventually discard this "zombie" item during its cleanup sweep.
+ finalState := item.FinalState()
+ return finalState.Outcome, finalState.Err
+
+ case finalState := <-item.Done():
+ // Synchronous Finalization (Processor-initiated):
+ // The processor finalized the item (Dispatch, Reject, Shutdown).
+ return finalState.Outcome, finalState.Err
+ }
+}
+
+// createRequestContext derives the context that governs a request's lifecycle, enforcing the TTL deadline.
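+// For example (illustrative only): a request with no per-request TTL under DefaultRequestTTL = 5s gets a context whose
+// deadline is enqueueTime + 5s with cause types.ErrTTLExpired; if both TTLs are zero, the returned context carries no
+// deadline and is bounded only by the caller's cancellation.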
+func (fc *FlowController) createRequestContext(
+ ctx context.Context,
+ req types.FlowControlRequest,
+) (context.Context, context.CancelFunc, time.Time) {
+ enqueueTime := fc.clock.Now()
+ effectiveTTL := req.InitialEffectiveTTL()
+ if effectiveTTL <= 0 {
+ effectiveTTL = fc.config.DefaultRequestTTL
+ }
+
+ if effectiveTTL > 0 {
+ reqCtx, cancel := context.WithDeadlineCause(ctx, enqueueTime.Add(effectiveTTL), types.ErrTTLExpired)
+ return reqCtx, cancel, enqueueTime
+ }
+ reqCtx, cancel := context.WithCancel(ctx)
+ return reqCtx, cancel, enqueueTime
+}
+
+// candidate holds the information needed to evaluate a shard as a potential target for a request.
+type candidate struct {
+ processor shardProcessor
+ shardID string
+ byteSize uint64
+}
+
+// selectDistributionCandidates identifies all Active shards for the item's flow and ranks them by the current byte size
+// of that flow's queue, from least to most loaded.
+func (fc *FlowController) selectDistributionCandidates(key types.FlowKey) ([]candidate, error) {
+ var candidates []candidate
+
+ // Acquire a connection to the registry for the flow key. This ensures a consistent view of the ActiveShards for the
+ // duration of the shard selection process, preventing races with concurrent shard topology changes.
+ err := fc.registry.WithConnection(key, func(conn contracts.ActiveFlowConnection) error {
+ shards := conn.ActiveShards()
+ candidates = make([]candidate, 0, len(shards))
+ for _, shard := range shards {
+ worker := fc.getOrStartWorker(shard)
+ mq, err := shard.ManagedQueue(key)
+ if err != nil {
+ fc.logger.Error(err,
+ "Invariant violation. Failed to get ManagedQueue for a leased flow on an Active shard. Skipping shard.",
+ "flowKey", key, "shardID", shard.ID())
+ continue
+ }
+ candidates = append(candidates, candidate{worker.processor, shard.ID(), mq.ByteSize()})
+ }
+ return nil
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed to acquire lease for flow %s: %w", key, err)
+ }
+
+ if len(candidates) == 0 {
+ return nil, fmt.Errorf("%w for flow %s", errNoShards, key)
+ }
+
+ slices.SortFunc(candidates, func(a, b candidate) int {
+ return cmp.Compare(a.byteSize, b.byteSize)
+ })
+
+ return candidates, nil
+}
+
+// distributeRequest implements a flow-aware, two-phase "Join-Shortest-Queue-by-Bytes" (JSQ-Bytes) distribution strategy
+// with graceful backpressure. It attempts to submit an item to the best-ranked candidate from the provided list.
+//
+// The algorithm operates as follows:
+// 1. Phase 1 (Non-blocking Fast Failover): It iterates through the ranked candidates and attempts a non-blocking
+// submission. The first successful submission wins.
+// 2. Phase 2 (Blocking Fallback): If all non-blocking attempts fail, it performs a single blocking submission to the
+// least-loaded candidate, providing backpressure.
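+//
+// For example (illustrative only): given ranked candidates [B: 100B, C: 500B, A: 1000B], the controller first tries a
+// non-blocking Submit on B, then C, then A; if every processor is busy, it falls back to a single blocking
+// SubmitOrBlock on B, the least-loaded candidate.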
+//
+// The provided context (ctx) is used for the blocking submission phase (SubmitOrBlock).
+//
+// Ownership Contract:
+// - Returns nil: Success. Ownership transferred to Processor.
+// - Returns error: Failure (Context expiry, shutdown, etc.).
+// Ownership retained by Controller. The Controller MUST finalize the item.
+func (fc *FlowController) distributeRequest(
+ ctx context.Context,
+ item *internal.FlowItem,
+ candidates []candidate,
+) (types.QueueOutcome, error) {
+ reqID := item.OriginalRequest().ID()
+ for _, c := range candidates {
+ if err := c.processor.Submit(item); err == nil {
+ return types.QueueOutcomeNotYetFinalized, nil
+ }
+ fc.logger.V(logutil.TRACE).Info("Processor busy during fast failover, trying next candidate",
+ "shardID", c.shardID, "requestID", reqID)
+ }
+
+ // All processors are busy. Attempt a single blocking submission to the least-loaded candidate.
+ bestCandidate := candidates[0]
+ fc.logger.V(logutil.TRACE).Info("All processors busy, attempting blocking submit to best candidate",
+ "shardID", bestCandidate.shardID, "requestID", reqID)
+ err := bestCandidate.processor.SubmitOrBlock(ctx, item)
+ if err != nil {
+ return types.QueueOutcomeRejectedOther, fmt.Errorf("%w: request not accepted: %w", types.ErrRejected, err)
+ }
+ return types.QueueOutcomeNotYetFinalized, nil // Success, ownership transferred.
+}
+
+// getOrStartWorker implements the lazy-loading and startup of shard processors.
+// It ensures that exactly one worker goroutine is started for each shard, using atomic operations
+// (sync.Map.LoadOrStore). The worker's processor goroutine is only started after it has successfully been registered,
+// preventing race conditions where multiple goroutines create and start the same worker.
+func (fc *FlowController) getOrStartWorker(shard contracts.RegistryShard) *managedWorker {
+ if w, ok := fc.workers.Load(shard.ID()); ok {
+ return w.(*managedWorker)
+ }
+
+ // Construct a new worker, but do not start its goroutine yet.
+ processorCtx, cancel := context.WithCancel(fc.parentCtx)
+ processor := fc.shardProcessorFactory(
+ processorCtx,
+ shard,
+ fc.saturationDetector,
+ fc.clock,
+ fc.config.ExpiryCleanupInterval,
+ fc.config.EnqueueChannelBufferSize,
+ fc.logger.WithValues("shardID", shard.ID()),
+ )
+ newWorker := &managedWorker{
+ processor: processor,
+ cancel: cancel,
+ }
+
+ // Atomically load or store. This is the critical synchronization step.
+ actual, loaded := fc.workers.LoadOrStore(shard.ID(), newWorker)
+ if loaded {
+ // Another goroutine beat us to it. The `newWorker` we created was not stored.
+ // We must cancel the context we created to prevent a leak.
+ cancel()
+ return actual.(*managedWorker)
+ }
+
+ // We won the race. The newWorker was stored. Now, start the processor's long-running goroutine.
+ fc.logger.V(logutil.DEFAULT).Info("Starting new ShardProcessor worker.", "shardID", shard.ID())
+ fc.wg.Add(1)
+ go func() {
+ defer fc.wg.Done()
+ processor.Run(processorCtx)
+ }()
+
+ return newWorker
+}
+
+// reconcileProcessors is the supervisor's core garbage collection loop.
+// It identifies and stops workers whose corresponding shards have been removed from the registry.
+func (fc *FlowController) reconcileProcessors() {
+ stats := fc.registry.ShardStats()
+ shards := make(map[string]struct{}, len(stats)) // set of shard IDs the registry currently reports
+ for _, s := range stats {
+ shards[s.ID] = struct{}{}
+ }
+
+ fc.workers.Range(func(key, value any) bool {
+ shardID := key.(string)
+ worker := value.(*managedWorker)
+ if _, exists := shards[shardID]; !exists {
+ fc.logger.V(logutil.DEFAULT).Info("Stale worker detected for GC'd shard, initiating shutdown.",
+ "shardID", shardID)
+ worker.cancel() // Cancel the worker's context, initiating the Processor's graceful shutdown sequence.
+ fc.workers.Delete(shardID) // Delete from the map so no new requests are routed to it.
+ }
+ return true
+ })
+}
+
+// shutdown gracefully terminates all running `shardProcessor` goroutines.
+// It signals all workers to stop and waits for them to complete their shutdown procedures.
+func (fc *FlowController) shutdown() {
+ fc.logger.Info("Shutting down FlowController and all shard processors.")
+ fc.workers.Range(func(key, value any) bool {
+ shardID := key.(string)
+ worker := value.(*managedWorker)
+ fc.logger.V(logutil.VERBOSE).Info("Sending shutdown signal to processor", "shardID", shardID)
+ worker.cancel()
+ return true
+ })
+ fc.wg.Wait()
+ fc.logger.Info("All shard processors have shut down.")
+}
diff --git a/pkg/epp/flowcontrol/controller/controller_test.go b/pkg/epp/flowcontrol/controller/controller_test.go
new file mode 100644
index 000000000..917a5a1e5
--- /dev/null
+++ b/pkg/epp/flowcontrol/controller/controller_test.go
@@ -0,0 +1,1303 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Note on Time-Based Lifecycle Tests:
+// Tests validating the controller's handling of request TTLs (e.g., OnReqCtxTimeout*) rely on real-time timers
+// (context.WithDeadline). The injected testclock.FakeClock is used to control the timing of internal loops (like
+// reconciliation), but it cannot manipulate the timers used by the standard context package. Therefore, these specific
+// tests use time.Sleep or assertions on real-time durations.
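+//
+// Illustrative sketch of the limitation (not executed by any test): stepping the fake clock past a deadline derived
+// from it does not fire the context's timer, because context deadlines are enforced by real-time runtime timers:
+//
+//	fake := testclock.NewFakeClock(time.Now())
+//	ctx, cancel := context.WithDeadline(context.Background(), fake.Now().Add(time.Second))
+//	defer cancel()
+//	fake.Step(2 * time.Second) // ctx is still not Done; only real elapsed time can expire it.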
+
+package controller
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/go-logr/logr"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "k8s.io/utils/clock"
+ testclock "k8s.io/utils/clock/testing"
+
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts/mocks"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/controller/internal"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework"
+ frameworkmocks "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/mocks"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types"
+ typesmocks "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types/mocks"
+)
+
+// --- Test Harness & Fixtures ---
+
+// withClock returns a test-only option to inject a clock.
+// test-only
+func withClock(c clock.WithTicker) flowControllerOption {
+ return func(fc *FlowController) {
+ fc.clock = c
+ }
+}
+
+// withRegistryClient returns a test-only option to inject a mock or fake registry client.
+// test-only
+func withRegistryClient(client registryClient) flowControllerOption {
+ return func(fc *FlowController) {
+ fc.registry = client
+ }
+}
+
+// withShardProcessorFactory returns a test-only option to inject a processor factory.
+// test-only
+func withShardProcessorFactory(factory shardProcessorFactory) flowControllerOption {
+ return func(fc *FlowController) {
+ fc.shardProcessorFactory = factory
+ }
+}
+
+// testHarness holds the `FlowController` and its dependencies under test.
+type testHarness struct {
+ fc *FlowController
+ cfg Config
+ // clock is the clock interface used by the controller.
+ clock clock.WithTicker
+ mockRegistry *mockRegistryClient
+ mockDetector *mocks.MockSaturationDetector
+ // mockClock provides access to FakeClock methods (Step, HasWaiters) if and only if the underlying clock is a
+ // FakeClock.
+ mockClock *testclock.FakeClock
+ mockProcessorFactory *mockShardProcessorFactory
+}
+
+// newUnitHarness creates a test environment with a mock processor factory, suitable for focused unit tests of the
+// controller's logic. It starts the controller's run loop using the provided context for lifecycle management.
+func newUnitHarness(t *testing.T, ctx context.Context, cfg Config, registry *mockRegistryClient) *testHarness {
+ t.Helper()
+ mockDetector := &mocks.MockSaturationDetector{}
+
+ // Initialize the FakeClock with the current system time.
+ // The controller implementation uses the injected clock to calculate the deadline timestamp, but uses the standard
+ // context.WithDeadline (which relies on the system clock) to enforce it.
+ // If the FakeClock's time is far from the system time, deadlines calculated based on the FakeClock might already be
+ // expired according to the system clock, causing immediate TTL failures.
+ mockClock := testclock.NewFakeClock(time.Now())
+
+ mockProcessorFactory := &mockShardProcessorFactory{
+ processors: make(map[string]*mockShardProcessor),
+ }
+
+ // Default the registry if nil, simplifying tests that don't focus on registry interaction.
+ if registry == nil {
+ registry = &mockRegistryClient{}
+ }
+
+ opts := []flowControllerOption{
+ withRegistryClient(registry),
+ withClock(mockClock),
+ withShardProcessorFactory(mockProcessorFactory.new),
+ }
+ fc, err := NewFlowController(ctx, cfg, registry, mockDetector, logr.Discard(), opts...)
+ require.NoError(t, err, "failed to create FlowController for unit test harness")
+
+ h := &testHarness{
+ fc: fc,
+ cfg: cfg,
+ clock: mockClock,
+ mockRegistry: registry,
+ mockDetector: mockDetector,
+ mockClock: mockClock,
+ mockProcessorFactory: mockProcessorFactory,
+ }
+ return h
+}
+
+// newIntegrationHarness creates a test environment that uses real `ShardProcessor`s, suitable for integration tests
+// validating the controller-processor interaction.
+func newIntegrationHarness(t *testing.T, ctx context.Context, cfg Config, registry *mockRegistryClient) *testHarness {
+ t.Helper()
+ mockDetector := &mocks.MockSaturationDetector{}
+ // Align FakeClock with system time. See explanation in newUnitHarness.
+ mockClock := testclock.NewFakeClock(time.Now())
+ if registry == nil {
+ registry = &mockRegistryClient{}
+ }
+
+ opts := []flowControllerOption{
+ withRegistryClient(registry),
+ withClock(mockClock),
+ }
+ fc, err := NewFlowController(ctx, cfg, registry, mockDetector, logr.Discard(), opts...)
+ require.NoError(t, err, "failed to create FlowController for integration test harness")
+
+ h := &testHarness{
+ fc: fc,
+ cfg: cfg,
+ clock: mockClock,
+ mockRegistry: registry,
+ mockDetector: mockDetector,
+ mockClock: mockClock,
+ }
+ return h
+}
+
+// mockActiveFlowConnection is a local mock for the `contracts.ActiveFlowConnection` interface.
+type mockActiveFlowConnection struct {
+ contracts.ActiveFlowConnection
+ ActiveShardsV []contracts.RegistryShard
+}
+
+func (m *mockActiveFlowConnection) ActiveShards() []contracts.RegistryShard {
+ return m.ActiveShardsV
+}
+
+// mockRegistryClient is a mock for the private `registryClient` interface.
+type mockRegistryClient struct {
+ contracts.FlowRegistryObserver
+ contracts.FlowRegistryDataPlane
+ WithConnectionFunc func(key types.FlowKey, fn func(conn contracts.ActiveFlowConnection) error) error
+ ShardStatsFunc func() []contracts.ShardStats
+}
+
+func (m *mockRegistryClient) WithConnection(
+ key types.FlowKey,
+ fn func(conn contracts.ActiveFlowConnection) error,
+) error {
+ if m.WithConnectionFunc != nil {
+ return m.WithConnectionFunc(key, fn)
+ }
+ return fn(&mockActiveFlowConnection{})
+}
+
+func (m *mockRegistryClient) ShardStats() []contracts.ShardStats {
+ if m.ShardStatsFunc != nil {
+ return m.ShardStatsFunc()
+ }
+ return nil
+}
+
+// mockShardProcessor is a mock for the internal `shardProcessor` interface.
+type mockShardProcessor struct {
+ SubmitFunc func(item *internal.FlowItem) error
+ SubmitOrBlockFunc func(ctx context.Context, item *internal.FlowItem) error
+ // runCtx captures the context provided to the Run method for lifecycle assertions.
+ runCtx context.Context
+ runCtxMu sync.RWMutex
+ // runStarted is closed when the Run method is called, allowing tests to synchronize with worker startup.
+ runStarted chan struct{}
+}
+
+func (m *mockShardProcessor) Submit(item *internal.FlowItem) error {
+ if m.SubmitFunc != nil {
+ return m.SubmitFunc(item)
+ }
+ return nil
+}
+
+func (m *mockShardProcessor) SubmitOrBlock(ctx context.Context, item *internal.FlowItem) error {
+ if m.SubmitOrBlockFunc != nil {
+ return m.SubmitOrBlockFunc(ctx, item)
+ }
+ return nil
+}
+
+func (m *mockShardProcessor) Run(ctx context.Context) {
+ m.runCtxMu.Lock()
+ m.runCtx = ctx
+ m.runCtxMu.Unlock()
+ if m.runStarted != nil {
+ close(m.runStarted)
+ }
+ // Block until the context is cancelled, simulating a running worker.
+ <-ctx.Done()
+}
+
+// Context returns the context captured during the Run method call.
+func (m *mockShardProcessor) Context() context.Context {
+ m.runCtxMu.RLock()
+ defer m.runCtxMu.RUnlock()
+ return m.runCtx
+}
+
+// mockShardProcessorFactory allows tests to inject specific `mockShardProcessor` instances.
+type mockShardProcessorFactory struct {
+ mu sync.Mutex
+ processors map[string]*mockShardProcessor
+}
+
+// new is the factory function conforming to the `shardProcessorFactory` signature.
+func (f *mockShardProcessorFactory) new(
+ _ context.Context, // The factory does not use the lifecycle context; it's passed to the processor's Run method later.
+ shard contracts.RegistryShard,
+ _ contracts.SaturationDetector,
+ _ clock.WithTicker,
+ _ time.Duration,
+ _ int,
+ _ logr.Logger,
+) shardProcessor {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ if proc, ok := f.processors[shard.ID()]; ok {
+ return proc
+ }
+ // Return a default mock processor if one is not explicitly registered by the test.
+ return &mockShardProcessor{}
+}
+
+// stubManagedQueue is a simple stub for the `contracts.ManagedQueue` interface.
+type stubManagedQueue struct {
+ contracts.ManagedQueue
+ byteSizeV uint64
+}
+
+func (s *stubManagedQueue) ByteSize() uint64 { return s.byteSizeV }
+
+func (s *stubManagedQueue) FlowQueueAccessor() framework.FlowQueueAccessor {
+ return &frameworkmocks.MockFlowQueueAccessor{ByteSizeV: s.byteSizeV}
+}
+
+// mockShardBuilder is a fixture to declaratively build mock `contracts.RegistryShard` for tests.
+type mockShardBuilder struct {
+ id string
+ byteSize uint64
+}
+
+func newMockShard(id string) *mockShardBuilder {
+ return &mockShardBuilder{id: id}
+}
+
+func (b *mockShardBuilder) withByteSize(size uint64) *mockShardBuilder {
+ b.byteSize = size
+ return b
+}
+
+func (b *mockShardBuilder) build() contracts.RegistryShard {
+ return &mocks.MockRegistryShard{
+ IDFunc: func() string { return b.id },
+ ManagedQueueFunc: func(_ types.FlowKey) (contracts.ManagedQueue, error) {
+ return &stubManagedQueue{byteSizeV: b.byteSize}, nil
+ },
+ }
+}
+
+var defaultFlowKey = types.FlowKey{ID: "test-flow", Priority: 100}
+
+func newTestRequest(key types.FlowKey) *typesmocks.MockFlowControlRequest {
+ return &typesmocks.MockFlowControlRequest{
+ FlowKeyV: key,
+ ByteSizeV: 100,
+ IDV: "req-" + key.ID,
+ }
+}
+
+// --- Test Cases ---
+
+// TestFlowController_EnqueueAndWait covers the primary API entry point, focusing on validation, distribution logic,
+// retries, and the request lifecycle (including post-distribution cancellation/timeout).
+func TestFlowController_EnqueueAndWait(t *testing.T) {
+ t.Parallel()
+
+ t.Run("Rejections", func(t *testing.T) {
+ t.Parallel()
+
+ t.Run("OnReqCtxExpiredBeforeDistribution", func(t *testing.T) {
+ t.Parallel()
+ // Test that if the request context provided to EnqueueAndWait is already expired, it returns immediately.
+ h := newUnitHarness(t, t.Context(), Config{DefaultRequestTTL: 1 * time.Minute}, nil)
+
+ // Configure registry to return a shard.
+ shardA := newMockShard("shard-A").build()
+ h.mockRegistry.WithConnectionFunc = func(_ types.FlowKey, fn func(_ contracts.ActiveFlowConnection) error) error {
+ return fn(&mockActiveFlowConnection{ActiveShardsV: []contracts.RegistryShard{shardA}})
+ }
+ // Configure processor to block until context expiry.
+ h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{
+ SubmitFunc: func(_ *internal.FlowItem) error { return internal.ErrProcessorBusy },
+ SubmitOrBlockFunc: func(ctx context.Context, _ *internal.FlowItem) error {
+ <-ctx.Done() // Wait for the context to be done.
+ return context.Cause(ctx) // Return the cause.
+ },
+ }
+
+ req := newTestRequest(defaultFlowKey)
+ // Use a context with a deadline in the past.
+ reqCtx, cancel := context.WithDeadlineCause(
+ context.Background(),
+ h.clock.Now().Add(-1*time.Second),
+ types.ErrTTLExpired)
+ defer cancel()
+
+ outcome, err := h.fc.EnqueueAndWait(reqCtx, req)
+ require.Error(t, err, "EnqueueAndWait must fail if request context deadline is exceeded")
+ assert.ErrorIs(t, err, types.ErrRejected, "error should wrap ErrRejected")
+ assert.ErrorIs(t, err, types.ErrTTLExpired, "error should wrap types.ErrTTLExpired from the context cause")
+ assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, "outcome should be QueueOutcomeRejectedOther")
+ })
+ t.Run("OnControllerShutdown", func(t *testing.T) {
+ t.Parallel()
+ // Create a context specifically for the controller's lifecycle.
+ ctx, cancel := context.WithCancel(t.Context())
+ h := newUnitHarness(t, ctx, Config{}, nil)
+ cancel() // Immediately stop the controller.
+
+ // Wait for the controller's run loop and all workers (none in this case) to exit.
+ // We need to wait because the shutdown process is asynchronous.
+ h.fc.wg.Wait()
+
+ req := newTestRequest(defaultFlowKey)
+ // The request context is valid, but the controller itself is stopped.
+ outcome, err := h.fc.EnqueueAndWait(context.Background(), req)
+ require.Error(t, err, "EnqueueAndWait must reject requests if controller is not running")
+ assert.ErrorIs(t, err, types.ErrRejected, "error should wrap ErrRejected")
+ assert.ErrorIs(t, err, types.ErrFlowControllerNotRunning, "error should wrap ErrFlowControllerNotRunning")
+ assert.Equal(t, types.QueueOutcomeRejectedOther, outcome,
+ "outcome should be QueueOutcomeRejectedOther on shutdown")
+ })
+
+ t.Run("OnNoShardsAvailable", func(t *testing.T) {
+ t.Parallel()
+ // The default mockRegistryClient returns an empty list of ActiveShards.
+ h := newUnitHarness(t, t.Context(), Config{}, nil)
+
+ req := newTestRequest(defaultFlowKey)
+ outcome, err := h.fc.EnqueueAndWait(context.Background(), req)
+ require.Error(t, err, "EnqueueAndWait must reject requests if no shards are available")
+ assert.ErrorIs(t, err, types.ErrRejected, "error should wrap ErrRejected")
+ assert.Equal(t, types.QueueOutcomeRejectedCapacity, outcome,
+ "outcome should be QueueOutcomeRejectedCapacity when no shards exist for the flow")
+ })
+
+ t.Run("OnRegistryConnectionError", func(t *testing.T) {
+ t.Parallel()
+ mockRegistry := &mockRegistryClient{}
+ h := newUnitHarness(t, t.Context(), Config{}, mockRegistry)
+
+ expectedErr := errors.New("simulated connection failure")
+ // Configure the registry to fail when attempting to retrieve ActiveFlowConnection.
+ mockRegistry.WithConnectionFunc = func(
+ _ types.FlowKey,
+ _ func(conn contracts.ActiveFlowConnection) error,
+ ) error {
+ return expectedErr
+ }
+
+ req := newTestRequest(defaultFlowKey)
+ outcome, err := h.fc.EnqueueAndWait(context.Background(), req)
+ require.Error(t, err, "EnqueueAndWait must reject requests if registry connection fails")
+ assert.ErrorIs(t, err, types.ErrRejected, "error should wrap ErrRejected")
+ assert.ErrorIs(t, err, expectedErr, "error should wrap the underlying connection error")
+ assert.Equal(t, types.QueueOutcomeRejectedOther, outcome,
+ "outcome should be QueueOutcomeRejectedOther for transient registry errors")
+ })
+
+ t.Run("OnManagedQueueError", func(t *testing.T) {
+ t.Parallel()
+ mockRegistry := &mockRegistryClient{}
+ h := newUnitHarness(t, t.Context(), Config{}, mockRegistry)
+
+ // Create a faulty shard that successfully leases the flow but fails to return the
+ // ManagedQueue. Such a shard should be treated as unavailable.
+ faultyShard := &mocks.MockRegistryShard{
+ IDFunc: func() string { return "faulty-shard" },
+ ManagedQueueFunc: func(_ types.FlowKey) (contracts.ManagedQueue, error) {
+ return nil, errors.New("invariant violation: queue retrieval failed")
+ },
+ }
+ mockRegistry.WithConnectionFunc = func(
+ _ types.FlowKey,
+ fn func(conn contracts.ActiveFlowConnection) error,
+ ) error {
+ return fn(&mockActiveFlowConnection{ActiveShardsV: []contracts.RegistryShard{faultyShard}})
+ }
+
+ req := newTestRequest(defaultFlowKey)
+ outcome, err := h.fc.EnqueueAndWait(context.Background(), req)
+ require.Error(t, err, "EnqueueAndWait must reject requests if no shards are available")
+ assert.ErrorIs(t, err, types.ErrRejected, "error should wrap ErrRejected")
+ assert.Equal(t, types.QueueOutcomeRejectedCapacity, outcome,
+ "outcome should be QueueOutcomeRejectedCapacity when no shards exist for the flow")
+ })
+ })
+
+ // Distribution tests validate the JSQ-Bytes algorithm, the two-phase submission strategy, and error handling during
+ // the handoff, including time-based failures during blocking fallback.
+ t.Run("Distribution", func(t *testing.T) {
+ t.Parallel()
+
+ // Define a long default TTL to prevent unexpected timeouts unless a test case explicitly sets a shorter one.
+ const defaultTestTTL = 5 * time.Second
+
+ testCases := []struct {
+ name string
+ shards []contracts.RegistryShard
+ setupProcessors func(t *testing.T, h *testHarness)
+ // requestTTL overrides the default TTL for time-sensitive tests.
+ requestTTL time.Duration
+ expectedOutcome types.QueueOutcome
+ expectErr bool
+ expectErrIs error
+ }{
+ {
+ name: "SubmitSucceeds_NonBlocking_WithSingleActiveShard",
+ shards: []contracts.RegistryShard{newMockShard("shard-A").build()},
+ setupProcessors: func(t *testing.T, h *testHarness) {
+ h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{
+ SubmitFunc: func(item *internal.FlowItem) error {
+ // Simulate asynchronous processing and successful dispatch.
+ go item.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil)
+ return nil
+ },
+ }
+ },
+ expectedOutcome: types.QueueOutcomeDispatched,
+ },
+ {
+ name: "DistributesToLeastLoadedShard_WithMultipleActiveShards",
+ shards: []contracts.RegistryShard{
+ newMockShard("shard-A").withByteSize(1000).build(), // More loaded
+ newMockShard("shard-B").withByteSize(100).build(), // Least loaded
+ },
+ setupProcessors: func(t *testing.T, h *testHarness) {
+ h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{
+ SubmitFunc: func(_ *internal.FlowItem) error {
+ t.Error("Submit was called on the more loaded shard (shard-A); JSQ-Bytes algorithm failed")
+ return internal.ErrProcessorBusy
+ },
+ }
+ h.mockProcessorFactory.processors["shard-B"] = &mockShardProcessor{
+ SubmitFunc: func(item *internal.FlowItem) error {
+ item.SetHandle(&typesmocks.MockQueueItemHandle{})
+ go item.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil)
+ return nil
+ },
+ }
+ },
+ expectedOutcome: types.QueueOutcomeDispatched,
+ },
+ {
+ name: "SubmitSucceeds_AfterBlocking_WithAllProcessorsBusy",
+ shards: []contracts.RegistryShard{
+ newMockShard("shard-A").withByteSize(1000).build(),
+ newMockShard("shard-B").withByteSize(100).build(),
+ },
+ setupProcessors: func(t *testing.T, h *testHarness) {
+ // Both processors reject the initial non-blocking Submit.
+ h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{
+ SubmitFunc: func(_ *internal.FlowItem) error { return internal.ErrProcessorBusy },
+ }
+ // Shard-B is the least loaded, so it should receive the blocking fallback (SubmitOrBlock).
+ h.mockProcessorFactory.processors["shard-B"] = &mockShardProcessor{
+ SubmitFunc: func(_ *internal.FlowItem) error { return internal.ErrProcessorBusy },
+ SubmitOrBlockFunc: func(_ context.Context, item *internal.FlowItem) error {
+ // The blocking call succeeds.
+ go item.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil)
+ return nil
+ },
+ }
+ },
+ expectedOutcome: types.QueueOutcomeDispatched,
+ },
+ {
+ // Validates the scenario where the request's TTL expires while the controller is blocked waiting for capacity.
+ // NOTE: This relies on real time passing, as context.WithDeadline timers cannot be controlled by FakeClock.
+ name: "Rejects_AfterBlocking_WhenTTL_Expires",
+ shards: []contracts.RegistryShard{newMockShard("shard-A").build()},
+ requestTTL: 50 * time.Millisecond, // Short TTL to keep the test fast.
+ setupProcessors: func(t *testing.T, h *testHarness) {
+ h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{
+ // Reject the non-blocking attempt.
+ SubmitFunc: func(_ *internal.FlowItem) error { return internal.ErrProcessorBusy },
+ // Block the fallback attempt until the context (carrying the TTL deadline) expires.
+ SubmitOrBlockFunc: func(ctx context.Context, _ *internal.FlowItem) error {
+ <-ctx.Done()
+ return ctx.Err()
+ },
+ }
+ },
+ // No extra orchestration needed; we rely on the real-time timer to expire.
+ // When the blocking call fails due to context expiry, the outcome is RejectedOther.
+ expectedOutcome: types.QueueOutcomeRejectedOther,
+ expectErr: true,
+ // The error must reflect the specific cause of the context cancellation (ErrTTLExpired).
+ expectErrIs: types.ErrTTLExpired,
+ },
+ {
+ name: "Rejects_OnProcessorShutdownDuringSubmit",
+ shards: []contracts.RegistryShard{newMockShard("shard-A").build()},
+ setupProcessors: func(t *testing.T, h *testHarness) {
+ h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{
+ // Simulate the processor shutting down during the non-blocking handoff.
+ SubmitFunc: func(_ *internal.FlowItem) error { return types.ErrFlowControllerNotRunning },
+ SubmitOrBlockFunc: func(_ context.Context, _ *internal.FlowItem) error {
+ return types.ErrFlowControllerNotRunning
+ },
+ }
+ },
+ expectedOutcome: types.QueueOutcomeRejectedOther,
+ expectErr: true,
+ expectErrIs: types.ErrFlowControllerNotRunning,
+ },
+ {
+ name: "Rejects_OnProcessorShutdownDuringSubmitOrBlock",
+ shards: []contracts.RegistryShard{newMockShard("shard-A").build()},
+ setupProcessors: func(t *testing.T, h *testHarness) {
+ h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{
+ SubmitFunc: func(_ *internal.FlowItem) error { return internal.ErrProcessorBusy },
+ // Simulate the processor shutting down during the blocking handoff.
+ SubmitOrBlockFunc: func(_ context.Context, _ *internal.FlowItem) error {
+ return types.ErrFlowControllerNotRunning
+ },
+ }
+ },
+ expectedOutcome: types.QueueOutcomeRejectedOther,
+ expectErr: true,
+ expectErrIs: types.ErrFlowControllerNotRunning,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+
+ // Arrange
+ mockRegistry := &mockRegistryClient{}
+
+ // Configure the harness with the appropriate TTL.
+ harnessConfig := Config{DefaultRequestTTL: defaultTestTTL}
+ if tc.requestTTL > 0 {
+ harnessConfig.DefaultRequestTTL = tc.requestTTL
+ }
+ h := newUnitHarness(t, t.Context(), harnessConfig, mockRegistry)
+
+ // Configure the registry to return the specified shards.
+ mockRegistry.WithConnectionFunc = func(
+ _ types.FlowKey,
+ fn func(conn contracts.ActiveFlowConnection) error,
+ ) error {
+ return fn(&mockActiveFlowConnection{ActiveShardsV: tc.shards})
+ }
+ tc.setupProcessors(t, h)
+
+ // Act
+ var outcome types.QueueOutcome
+ var err error
+
+ startTime := time.Now() // Capture real start time for duration checks.
+ // Use a background context for the parent; the request lifecycle is governed by the config/derived context.
+ outcome, err = h.fc.EnqueueAndWait(context.Background(), newTestRequest(defaultFlowKey))
+
+ // Assert
+ if tc.expectErr {
+ require.Error(t, err, "expected an error during EnqueueAndWait but got nil")
+ assert.ErrorIs(t, err, tc.expectErrIs, "error should wrap the expected underlying cause")
+ // All failures during the distribution phase (capacity, timeout, shutdown) should result in a rejection.
+ assert.ErrorIs(t, err, types.ErrRejected, "rejection errors must wrap types.ErrRejected")
+
+ // Specific assertion for real-time TTL tests.
+ if errors.Is(tc.expectErrIs, types.ErrTTLExpired) {
+ duration := time.Since(startTime)
+ // Ensure the test didn't return instantly. Use a tolerance for CI environments.
+ // This validates that the real-time wait actually occurred.
+ assert.GreaterOrEqual(t, duration, tc.requestTTL-30*time.Millisecond,
+ "EnqueueAndWait returned faster than the TTL allows, indicating the timer did not function correctly")
+ }
+
+ } else {
+ require.NoError(t, err, "expected no error during EnqueueAndWait but got: %v", err)
+ }
+ assert.Equal(t, tc.expectedOutcome, outcome, "outcome did not match expected value")
+ })
+ }
+ })
+
+ t.Run("Retry", func(t *testing.T) {
+ t.Parallel()
+
+ // This test specifically validates the behavior when the request context is cancelled externally while the
+ // controller is blocked in the SubmitOrBlock phase.
+ t.Run("Rejects_OnRequestContextCancelledWhileBlocking", func(t *testing.T) {
+ t.Parallel()
+ mockRegistry := &mockRegistryClient{
+ WithConnectionFunc: func(
+ _ types.FlowKey,
+ fn func(conn contracts.ActiveFlowConnection) error,
+ ) error {
+ return fn(&mockActiveFlowConnection{
+ ActiveShardsV: []contracts.RegistryShard{newMockShard("shard-A").build()},
+ })
+ },
+ }
+ // Use a long TTL to ensure the failure is due to cancellation, not timeout.
+ h := newUnitHarness(t, t.Context(), Config{DefaultRequestTTL: 10 * time.Second}, mockRegistry)
+ h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{
+ // Reject non-blocking attempt.
+ SubmitFunc: func(_ *internal.FlowItem) error { return internal.ErrProcessorBusy },
+ // Block the fallback attempt until the context is cancelled.
+ SubmitOrBlockFunc: func(ctx context.Context, _ *internal.FlowItem) error {
+ <-ctx.Done()
+ return ctx.Err()
+ },
+ }
+
+ // Create a cancellable context for the request.
+ reqCtx, cancelReq := context.WithCancel(context.Background())
+ // Cancel the request shortly after starting the operation.
+ // We use real time sleep here as we are testing external cancellation signals interacting with the context.
+ go func() { time.Sleep(10 * time.Millisecond); cancelReq() }()
+
+ outcome, err := h.fc.EnqueueAndWait(reqCtx, newTestRequest(defaultFlowKey))
+
+ require.Error(t, err, "EnqueueAndWait must fail when context is cancelled during a blocking submit")
+ assert.ErrorIs(t, err, types.ErrRejected, "error should wrap ErrRejected")
+ assert.ErrorIs(t, err, context.Canceled, "error should wrap the underlying ctx.Err() (context.Canceled)")
+ assert.Equal(t, types.QueueOutcomeRejectedOther, outcome,
+ "outcome should be QueueOutcomeRejectedOther when cancelled during distribution")
+ })
+
+ // This test validates the retry mechanism when a processor reports that its shard is draining.
+ t.Run("RetriesAndSucceeds_OnProcessorReportsShardDraining", func(t *testing.T) {
+ t.Parallel()
+ var callCount atomic.Int32
+ mockRegistry := &mockRegistryClient{
+ WithConnectionFunc: func(
+ _ types.FlowKey,
+ fn func(conn contracts.ActiveFlowConnection) error,
+ ) error {
+ attempt := callCount.Add(1)
+ shardA := newMockShard("shard-A").withByteSize(100).build()
+ shardB := newMockShard("shard-B").withByteSize(1000).build()
+
+ if attempt == 1 {
+ // Attempt 1: Shard A is the least loaded and is selected.
+ return fn(&mockActiveFlowConnection{ActiveShardsV: []contracts.RegistryShard{shardA, shardB}})
+ }
+ // Attempt 2 (Retry): Assume Shard A is now draining and removed from the active set by the registry.
+ return fn(&mockActiveFlowConnection{ActiveShardsV: []contracts.RegistryShard{shardB}})
+ },
+ }
+ // Use a long TTL to ensure retries don't time out.
+ h := newUnitHarness(t, t.Context(), Config{DefaultRequestTTL: 10 * time.Second}, mockRegistry)
+
+ // Configure Shard A's processor to reject the request due to draining.
+ h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{
+ SubmitFunc: func(item *internal.FlowItem) error {
+ // The processor accepts the item but then asynchronously finalizes it with ErrShardDraining.
+ item.SetHandle(&typesmocks.MockQueueItemHandle{})
+ go item.FinalizeWithOutcome(types.QueueOutcomeRejectedOther, contracts.ErrShardDraining)
+ return nil
+ },
+ }
+ // Configure Shard B's processor to successfully dispatch the request on the retry.
+ h.mockProcessorFactory.processors["shard-B"] = &mockShardProcessor{
+ SubmitFunc: func(item *internal.FlowItem) error {
+ go item.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil)
+ return nil
+ },
+ }
+
+ // Act
+ outcome, err := h.fc.EnqueueAndWait(context.Background(), newTestRequest(defaultFlowKey))
+
+ // Assert
+ require.NoError(t, err, "EnqueueAndWait must succeed after retrying on a healthy shard")
+ assert.Equal(t, types.QueueOutcomeDispatched, outcome, "outcome should be QueueOutcomeDispatched")
+ assert.Equal(t, int32(2), callCount.Load(), "registry must be consulted for Active shards on each retry attempt")
+ })
+ })
+
+ // Lifecycle covers the post-distribution phase, focusing on how the controller handles context cancellation and TTL
+ // expiry while the request is buffered or queued by the processor (Asynchronous Finalization).
+ t.Run("Lifecycle", func(t *testing.T) {
+ t.Parallel()
+
+ // Validates that the controller correctly initiates asynchronous finalization when the request context is cancelled
+ // after ownership has been transferred to the processor.
+ t.Run("OnReqCtxCancelledAfterDistribution", func(t *testing.T) {
+ t.Parallel()
+ // Use a long TTL to ensure the failure is due to cancellation.
+ h := newUnitHarness(t, t.Context(), Config{DefaultRequestTTL: 10 * time.Second}, nil)
+
+ shardA := newMockShard("shard-A").build()
+ h.mockRegistry.WithConnectionFunc = func(_ types.FlowKey, fn func(_ contracts.ActiveFlowConnection) error) error {
+ return fn(&mockActiveFlowConnection{ActiveShardsV: []contracts.RegistryShard{shardA}})
+ }
+
+ // Channel for synchronization.
+ itemSubmitted := make(chan *internal.FlowItem, 1)
+
+ // Configure the processor to accept the item but never finalize it, simulating a queued request.
+ h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{
+ SubmitFunc: func(item *internal.FlowItem) error {
+ item.SetHandle(&typesmocks.MockQueueItemHandle{})
+ itemSubmitted <- item
+ return nil
+ },
+ }
+
+ reqCtx, cancelReq := context.WithCancel(context.Background())
+ req := newTestRequest(defaultFlowKey)
+
+ var outcome types.QueueOutcome
+ var err error
+ done := make(chan struct{})
+ go func() {
+ outcome, err = h.fc.EnqueueAndWait(reqCtx, req)
+ close(done)
+ }()
+
+ // 1. Wait for the item to be successfully distributed.
+ var item *internal.FlowItem
+ select {
+ case item = <-itemSubmitted:
+ // Success. Ownership has transferred. EnqueueAndWait is now in the select loop.
+ case <-time.After(1 * time.Second):
+ t.Fatal("timed out waiting for item to be submitted to the processor")
+ }
+
+ // 2. Cancel the request context.
+ cancelReq()
+
+ // 3. Wait for EnqueueAndWait to return.
+ select {
+ case <-done:
+ // Success. The controller detected the cancellation and unblocked the caller.
+ case <-time.After(1 * time.Second):
+ t.Fatal("timed out waiting for EnqueueAndWait to return after cancellation")
+ }
+
+ // 4. Assertions for EnqueueAndWait's return values.
+ require.Error(t, err, "EnqueueAndWait should return an error when the request is cancelled post-distribution")
+ // The outcome should be Evicted (as the handle was set).
+ assert.ErrorIs(t, err, types.ErrEvicted, "error should wrap ErrEvicted")
+ // The underlying cause must be propagated.
+ assert.ErrorIs(t, err, types.ErrContextCancelled, "error should wrap ErrContextCancelled")
+ assert.Equal(t, types.QueueOutcomeEvictedContextCancelled, outcome, "outcome should be EvictedContextCancelled")
+
+ // 5. Assert that the FlowItem itself was indeed finalized by the controller.
+ finalState := item.FinalState()
+ require.NotNil(t, finalState, "Item should have been finalized asynchronously by the controller")
+ assert.Equal(t, types.QueueOutcomeEvictedContextCancelled, finalState.Outcome,
+ "Item's internal outcome must match the returned outcome")
+ })
+
+ // Validates the asynchronous finalization path due to TTL expiry.
+ // Note: This relies on real time passing, as context.WithDeadline timers cannot be controlled by FakeClock.
+ t.Run("OnReqCtxTimeoutAfterDistribution", func(t *testing.T) {
+ t.Parallel()
+ // Configure a short TTL to keep the test reasonably fast.
+ const requestTTL = 50 * time.Millisecond
+ h := newUnitHarness(t, t.Context(), Config{DefaultRequestTTL: requestTTL}, nil)
+
+ shardA := newMockShard("shard-A").build()
+ h.mockRegistry.WithConnectionFunc = func(_ types.FlowKey, fn func(_ contracts.ActiveFlowConnection) error) error {
+ return fn(&mockActiveFlowConnection{ActiveShardsV: []contracts.RegistryShard{shardA}})
+ }
+
+ itemSubmitted := make(chan *internal.FlowItem, 1)
+
+ // Configure the processor to accept the item but never finalize it.
+ h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{
+ SubmitFunc: func(item *internal.FlowItem) error {
+ item.SetHandle(&typesmocks.MockQueueItemHandle{})
+ itemSubmitted <- item
+ return nil
+ },
+ }
+
+ req := newTestRequest(defaultFlowKey)
+ // Use a context for the call itself that won't time out independently.
+ enqueueCtx, enqueueCancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer enqueueCancel()
+
+ var outcome types.QueueOutcome
+ var err error
+ done := make(chan struct{})
+
+ startTime := time.Now() // Capture start time to validate duration.
+ go func() {
+ outcome, err = h.fc.EnqueueAndWait(enqueueCtx, req)
+ close(done)
+ }()
+
+ // 1. Wait for the item to be submitted.
+ var item *internal.FlowItem
+ select {
+ case item = <-itemSubmitted:
+ case <-time.After(1 * time.Second):
+ t.Fatal("timed out waiting for item to be submitted to the processor")
+ }
+
+ // 2. Wait for the TTL to expire (real time); we do NOT call Step().
+ // 3. Wait for EnqueueAndWait to return due to the TTL expiry.
+ select {
+ case <-done:
+ // Success. Now validate that enough time actually passed.
+ duration := time.Since(startTime)
+ assert.GreaterOrEqual(t, duration, requestTTL-30*time.Millisecond, // tolerance for CI environments
+ "EnqueueAndWait returned faster than the TTL allows, indicating the timer did not function correctly")
+ case <-time.After(1 * time.Second):
+ t.Fatal("timed out waiting for EnqueueAndWait to return after TTL expiry")
+ }
+
+ // 4. Assertions for EnqueueAndWait's return values.
+ require.Error(t, err, "EnqueueAndWait should return an error when TTL expires post-distribution")
+ assert.ErrorIs(t, err, types.ErrEvicted, "error should wrap ErrEvicted")
+ assert.ErrorIs(t, err, types.ErrTTLExpired, "error should wrap the underlying cause (types.ErrTTLExpired)")
+ assert.Equal(t, types.QueueOutcomeEvictedTTL, outcome, "outcome should be EvictedTTL")
+
+ // 5. Assert FlowItem final state.
+ finalState := item.FinalState()
+ require.NotNil(t, finalState, "Item should have been finalized asynchronously by the controller")
+ assert.Equal(t, types.QueueOutcomeEvictedTTL, finalState.Outcome,
+ "Item's internal outcome must match the returned outcome")
+ })
+ })
+}
+
+// TestFlowController_WorkerManagement covers the lifecycle of the shard processors (workers), including startup,
+// reconciliation (garbage collection), and shutdown.
+func TestFlowController_WorkerManagement(t *testing.T) {
+ t.Parallel()
+
+ // Reconciliation validates that the controller correctly identifies and shuts down workers whose shards no longer
+ // exist in the registry.
+ t.Run("Reconciliation", func(t *testing.T) {
+ t.Parallel()
+
+ // Setup: The registry reports only "shard-A"; workers for both "shard-A" and "stale-shard" are pre-populated below
+ // to simulate a shard that has since been removed from the registry.
+ mockRegistry := &mockRegistryClient{
+ ShardStatsFunc: func() []contracts.ShardStats {
+ // The current state of the world according to the registry.
+ return []contracts.ShardStats{{ID: "shard-A"}}
+ }}
+ h := newUnitHarness(t, t.Context(), Config{}, mockRegistry)
+
+ // Pre-populate the controller with initial workers, simulating a previous state.
+ initialShards := []string{"shard-A", "stale-shard"}
+ for _, shardID := range initialShards {
+ currentShardID := shardID
+ // Initialize the processor mocks with the channel needed to synchronize startup.
+ h.mockProcessorFactory.processors[currentShardID] = &mockShardProcessor{runStarted: make(chan struct{})}
+ shard := &mocks.MockRegistryShard{IDFunc: func() string { return currentShardID }}
+ // Start the worker using the internal mechanism.
+ h.fc.getOrStartWorker(shard)
+ }
+ require.Len(t, h.mockProcessorFactory.processors, 2, "pre-condition: initial workers not set up correctly")
+
+ // Wait for all worker goroutines to have started and captured their contexts.
+ for id, p := range h.mockProcessorFactory.processors {
+ proc := p
+ select {
+ case <-proc.runStarted:
+ // Worker is running.
+ case <-time.After(2 * time.Second):
+ t.Fatalf("timed out waiting for worker %s to start", id)
+ }
+ }
+
+ // Act: Manually trigger the reconciliation logic.
+ h.fc.reconcileProcessors()
+
+ t.Run("StaleWorkerIsCancelled", func(t *testing.T) {
+ staleProc := h.mockProcessorFactory.processors["stale-shard"]
+ require.NotNil(t, staleProc.Context(), "precondition: stale processor context should have been captured")
+ // The context of the removed worker must be cancelled to signal shutdown.
+ select {
+ case <-staleProc.Context().Done():
+ // Success: Context was cancelled.
+ case <-time.After(100 * time.Millisecond):
+ t.Error("context of the stale worker was not cancelled during reconciliation")
+ }
+ })
+
+ t.Run("ActiveWorkerIsNotCancelled", func(t *testing.T) {
+ activeProc := h.mockProcessorFactory.processors["shard-A"]
+ require.NotNil(t, activeProc.Context(), "precondition: active processor context should have been captured")
+ // The context of an active worker must remain open.
+ select {
+ case <-activeProc.Context().Done():
+ t.Error("context of the active worker was incorrectly cancelled during reconciliation")
+ default:
+ // Success: Context is still active.
+ }
+ })
+
+ t.Run("WorkerMapIsUpdated", func(t *testing.T) {
+ // The stale worker must be removed from the controller's concurrent map.
+ _, ok := h.fc.workers.Load("stale-shard")
+ assert.False(t, ok, "stale worker must be deleted from the controller's map")
+ _, ok = h.fc.workers.Load("shard-A")
+ assert.True(t, ok, "active worker must remain in the controller's map")
+ })
+ })
+
+ // Validates that the reconciliation loop runs periodically based on the configured interval.
+ t.Run("Reconciliation_IsTriggeredByTicker", func(t *testing.T) {
+ t.Parallel()
+ const reconciliationInterval = 10 * time.Second
+ mockRegistry := &mockRegistryClient{}
+
+ // Count the number of times the reconciliation logic (which calls ShardStats) runs.
+ var reconcileCount atomic.Int32
+ mockRegistry.ShardStatsFunc = func() []contracts.ShardStats {
+ reconcileCount.Add(1)
+ return nil
+ }
+
+ h := newUnitHarness(t, t.Context(), Config{ProcessorReconciliationInterval: reconciliationInterval}, mockRegistry)
+ // Ensure we are using the FakeClock specifically for this test, as we need Step/HasWaiters.
+ require.NotNil(t, h.mockClock, "This test requires the harness to be using FakeClock")
+
+ // Wait for the reconciliation loop to start and create the ticker.
+ // This prevents a race where the clock is stepped before the ticker is registered with the FakeClock.
+ require.Eventually(t, h.mockClock.HasWaiters, time.Second, 10*time.Millisecond,
+ "reconciliation ticker was not created")
+
+ // Advance the clock to trigger the first reconciliation.
+ h.mockClock.Step(reconciliationInterval)
+
+ assert.Eventually(t, func() bool {
+ return reconcileCount.Load() == 1
+ }, time.Second, 10*time.Millisecond, "reconciliation was not triggered by the first ticker event")
+
+ // Advance the clock again to ensure it continues to fire.
+ h.mockClock.Step(reconciliationInterval)
+ assert.Eventually(t, func() bool {
+ return reconcileCount.Load() == 2
+ }, time.Second, 10*time.Millisecond, "reconciliation did not fire on the second ticker event")
+ })
+
+ // Validates the atomicity of worker creation and ensures resource cleanup for the loser of the race.
+ t.Run("WorkerCreationRace", func(t *testing.T) {
+ t.Parallel()
+
+ // This test orchestrates a deterministic race condition.
+ factoryEntered := make(chan *mockShardProcessor, 2)
+ continueFactory := make(chan struct{})
+ // Map to store the construction context for each processor instance, allowing us to verify cleanup.
+ constructionContexts := sync.Map{}
+
+ h := newUnitHarness(t, t.Context(), Config{}, nil)
+
+ // Inject a custom factory to control the timing of worker creation.
+ h.fc.shardProcessorFactory = func(
+ ctx context.Context, // The context created by getOrStartWorker for the potential new processor.
+ shard contracts.RegistryShard,
+ _ contracts.SaturationDetector,
+ _ clock.WithTicker,
+ _ time.Duration,
+ _ int,
+ _ logr.Logger,
+ ) shardProcessor {
+ // This function is called by getOrStartWorker before the LoadOrStore check.
+ proc := &mockShardProcessor{runStarted: make(chan struct{})}
+ constructionContexts.Store(proc, ctx) // Capture the construction context.
+
+ // Signal entry and then block, allowing another goroutine to enter.
+ factoryEntered <- proc
+ <-continueFactory
+ return proc
+ }
+
+ shard := newMockShard("race-shard").build()
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ // Start two goroutines that will race to create the same worker.
+ go func() {
+ defer wg.Done()
+ h.fc.getOrStartWorker(shard)
+ }()
+ go func() {
+ defer wg.Done()
+ h.fc.getOrStartWorker(shard)
+ }()
+
+ // 1. Wait for both goroutines to enter the factory and create their respective processor instances.
+ proc1 := <-factoryEntered
+ proc2 := <-factoryEntered
+
+ // 2. Unblock both goroutines, allowing them to race to workers.LoadOrStore.
+ close(continueFactory)
+ wg.Wait()
+
+ // 3. Identify the winner and the loser.
+ actual, ok := h.fc.workers.Load("race-shard")
+ require.True(t, ok, "a worker must have been successfully stored in the map")
+
+ storedWorker := actual.(*managedWorker)
+ winnerProc := storedWorker.processor.(*mockShardProcessor)
+
+ var loserProc *mockShardProcessor
+ if winnerProc == proc1 {
+ loserProc = proc2
+ } else {
+ loserProc = proc1
+ }
+
+ // 4. Validate the state of the winning processor.
+ // Wait for the Run method to be called on the winner (only the winner should start).
+ select {
+ case <-winnerProc.runStarted:
+ // Success.
+ case <-time.After(1 * time.Second):
+ t.Fatal("timed out waiting for the winning worker's Run method to be called")
+ }
+
+ // The winning processor's context must remain active.
+ require.NotNil(t, winnerProc.Context(), "winner's context should not be nil (Run was called)")
+ select {
+ case <-winnerProc.Context().Done():
+ t.Error("context of the winning worker should not be cancelled")
+ default:
+ // Success
+ }
+
+ // 5. Validate the state of the losing processor and resource cleanup.
+ // The losing processor's Run method must NOT be called.
+ select {
+ case <-loserProc.runStarted:
+ t.Error("Run was incorrectly called on the losing worker")
+ default:
+ // Success
+ }
+
+ // Verify the context created for the loser during construction was cancelled by getOrStartWorker.
+ loserCtxRaw, ok := constructionContexts.Load(loserProc)
+ require.True(t, ok, "loser processor construction context should have been captured")
+ loserCtx := loserCtxRaw.(context.Context)
+
+ select {
+ case <-loserCtx.Done():
+ // Success: Context was cancelled, preventing resource leaks.
+ case <-time.After(100 * time.Millisecond):
+ t.Error("context of the losing worker was not cancelled, this will leak resources")
+ }
+ })
+}
+
+// Helper function to create a realistic mock registry environment for integration/concurrency tests.
+func setupRegistryForConcurrency(t *testing.T, numShards int, flowKey types.FlowKey) *mockRegistryClient {
+ t.Helper()
+ mockRegistry := &mockRegistryClient{}
+ shards := make([]contracts.RegistryShard, numShards)
+
+ // Configure the shards and their dependencies required by the real ShardProcessor implementation.
+ for i := range numShards {
+ // Capture loop variables for closures.
+ shardID := fmt.Sprintf("shard-%d", i)
+ // Use high-fidelity mock queues (MockManagedQueue) that implement the necessary interfaces and synchronization.
+ currentQueue := &mocks.MockManagedQueue{FlowKeyV: flowKey}
+
+ shards[i] = &mocks.MockRegistryShard{
+ IDFunc: func() string { return shardID },
+ ManagedQueueFunc: func(_ types.FlowKey) (contracts.ManagedQueue, error) {
+ return currentQueue, nil
+ },
+ // Configuration required for ShardProcessor initialization and dispatch logic.
+ AllOrderedPriorityLevelsFunc: func() []int { return []int{flowKey.Priority} },
+ PriorityBandAccessorFunc: func(priority int) (framework.PriorityBandAccessor, error) {
+ if priority == flowKey.Priority {
+ return &frameworkmocks.MockPriorityBandAccessor{
+ PriorityV: priority,
+ IterateQueuesFunc: func(f func(framework.FlowQueueAccessor) bool) {
+ f(currentQueue.FlowQueueAccessor())
+ },
+ }, nil
+ }
+ return nil, fmt.Errorf("unexpected priority %d", priority)
+ },
+ // Configure dispatch policies (FIFO).
+ IntraFlowDispatchPolicyFunc: func(_ types.FlowKey) (framework.IntraFlowDispatchPolicy, error) {
+ return &frameworkmocks.MockIntraFlowDispatchPolicy{
+ SelectItemFunc: func(qa framework.FlowQueueAccessor) (types.QueueItemAccessor, error) {
+ return qa.PeekHead()
+ },
+ }, nil
+ },
+ InterFlowDispatchPolicyFunc: func(_ int) (framework.InterFlowDispatchPolicy, error) {
+ return &frameworkmocks.MockInterFlowDispatchPolicy{
+ SelectQueueFunc: func(band framework.PriorityBandAccessor) (framework.FlowQueueAccessor, error) {
+ return currentQueue.FlowQueueAccessor(), nil
+ },
+ }, nil
+ },
+ // Configure stats reporting based on the live state of the mock queues.
+ StatsFunc: func() contracts.ShardStats {
+ return contracts.ShardStats{
+ ID: shardID,
+ TotalLen: uint64(currentQueue.Len()),
+ TotalByteSize: currentQueue.ByteSize(),
+ PerPriorityBandStats: map[int]contracts.PriorityBandStats{
+ flowKey.Priority: {
+ Len: uint64(currentQueue.Len()),
+ ByteSize: currentQueue.ByteSize(),
+ CapacityBytes: 1e9, // Effectively unlimited capacity to ensure dispatch success.
+ },
+ },
+ }
+ },
+ }
+ }
+
+ // Configure the registry connection.
+ mockRegistry.WithConnectionFunc = func(_ types.FlowKey, fn func(conn contracts.ActiveFlowConnection) error) error {
+ return fn(&mockActiveFlowConnection{ActiveShardsV: shards})
+ }
+ mockRegistry.ShardStatsFunc = func() []contracts.ShardStats {
+ stats := make([]contracts.ShardStats, len(shards))
+ for i, shard := range shards {
+ stats[i] = shard.Stats()
+ }
+ return stats
+ }
+ return mockRegistry
+}
+
+// TestFlowController_Concurrency_Distribution performs an integration test under high contention, using real
+// ShardProcessors.
+// It validates the thread-safety of the distribution logic and the overall system throughput.
+func TestFlowController_Concurrency_Distribution(t *testing.T) {
+ const (
+ numShards = 4
+ numGoroutines = 50
+ numRequests = 200
+ )
+
+ // Arrange
+ mockRegistry := setupRegistryForConcurrency(t, numShards, defaultFlowKey)
+
+ // Initialize the integration harness with real ShardProcessors.
+ h := newIntegrationHarness(t, t.Context(), Config{
+ // Use a generous buffer to focus the test on distribution logic rather than backpressure.
+ EnqueueChannelBufferSize: numRequests,
+ DefaultRequestTTL: 5 * time.Second,
+ ExpiryCleanupInterval: 100 * time.Millisecond,
+ }, mockRegistry)
+
+ // Act: Hammer the controller concurrently.
+ var wg sync.WaitGroup
+ wg.Add(numGoroutines)
+ outcomes := make(chan types.QueueOutcome, numRequests)
+
+ for i := range numGoroutines {
+ goroutineID := i
+ go func() {
+ defer wg.Done()
+ for j := range numRequests / numGoroutines {
+ req := newTestRequest(defaultFlowKey)
+ req.IDV = fmt.Sprintf("req-distrib-%d-%d", goroutineID, j)
+
+ // Use a reasonable timeout for the individual request context.
+ reqCtx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
+ defer cancel()
+
+ ctx := logr.NewContext(reqCtx, logr.Discard())
+ outcome, err := h.fc.EnqueueAndWait(ctx, req)
+ if err != nil {
+ // Use t.Errorf for concurrent tests to report failures without halting execution.
+ t.Errorf("EnqueueAndWait failed unexpectedly under load: %v", err)
+ }
+ outcomes <- outcome
+ }
+ }()
+ }
+
+ // Wait for all requests to complete.
+ wg.Wait()
+ close(outcomes)
+
+ // Assert: All requests should be successfully dispatched.
+ successCount := 0
+ for outcome := range outcomes {
+ if outcome == types.QueueOutcomeDispatched {
+ successCount++
+ }
+ }
+ require.Equal(t, numRequests, successCount,
+ "all concurrent requests must be dispatched successfully without errors or data races")
+}
+
+// TestFlowController_Concurrency_Backpressure specifically targets the blocking submission path (SubmitOrBlock) by
+// configuring the processors with zero buffer capacity.
+func TestFlowController_Concurrency_Backpressure(t *testing.T) {
+ if testing.Short() {
+ t.Skip("Skipping concurrency integration test in short mode.")
+ }
+ t.Parallel()
+
+ const (
+ numShards = 2
+ numGoroutines = 20
+ // Fewer requests than the distribution test, as the blocking path is inherently slower.
+ numRequests = 40
+ )
+
+ // Arrange: Set up the registry environment.
+ mockRegistry := setupRegistryForConcurrency(t, numShards, defaultFlowKey)
+
+ // Use the integration harness with a configuration designed to induce backpressure.
+ h := newIntegrationHarness(t, t.Context(), Config{
+ // Zero buffer forces immediate use of SubmitOrBlock if the processor loop is busy.
+ EnqueueChannelBufferSize: 0,
+ // Generous TTL to ensure timeouts are not the cause of failure.
+ DefaultRequestTTL: 10 * time.Second,
+ ExpiryCleanupInterval: 100 * time.Millisecond,
+ }, mockRegistry)
+
+ // Act: Concurrently submit requests.
+ var wg sync.WaitGroup
+ wg.Add(numGoroutines)
+ outcomes := make(chan types.QueueOutcome, numRequests)
+
+ for i := range numGoroutines {
+ goroutineID := i
+ go func() {
+ defer wg.Done()
+ for j := range numRequests / numGoroutines {
+ req := newTestRequest(defaultFlowKey)
+ req.IDV = fmt.Sprintf("req-backpressure-%d-%d", goroutineID, j)
+
+ // Use a reasonable timeout for the individual request context to ensure the test finishes promptly if a
+ // deadlock occurs.
+ reqCtx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
+ defer cancel()
+
+ outcome, err := h.fc.EnqueueAndWait(logr.NewContext(reqCtx, logr.Discard()), req)
+ if err != nil {
+ t.Errorf("EnqueueAndWait failed unexpectedly under backpressure for request %s: %v", req.ID(), err)
+ }
+ outcomes <- outcome
+ }
+ }()
+ }
+ wg.Wait()
+ close(outcomes)
+
+ // Assert: Verify successful dispatch despite high contention and zero buffer.
+ successCount := 0
+ for outcome := range outcomes {
+ if outcome == types.QueueOutcomeDispatched {
+ successCount++
+ }
+ }
+ require.Equal(t, numRequests, successCount,
+ "all concurrent requests should be dispatched successfully even under high contention and zero buffer capacity")
+}
diff --git a/pkg/epp/flowcontrol/controller/doc.go b/pkg/epp/flowcontrol/controller/doc.go
index 8c96bbc18..0d2ea3687 100644
--- a/pkg/epp/flowcontrol/controller/doc.go
+++ b/pkg/epp/flowcontrol/controller/doc.go
@@ -14,109 +14,48 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
-// Package controller contains the implementation of the `FlowController` engine.
+// Package controller contains the implementation of the FlowController engine.
//
// # Overview
//
-// The `FlowController` is the central processing engine of the flow control system. It is a sharded, high-throughput
-// component responsible for managing the lifecycle of all incoming requests—from initial submission via the synchronous
-// `EnqueueAndWait` method to a terminal outcome (dispatch, rejection, or eviction). It achieves this by orchestrating
-// its dependencies—the `contracts.FlowRegistry`, the pluggable `Policy` framework, and the
-// `contracts.SaturationDetector`—to make continuous, state-aware decisions.
+// The FlowController is the central processing engine of the Flow Control layer. It acts as a stateless supervisor that
+// orchestrates a pool of stateful workers (internal.ShardProcessor), managing the lifecycle of all incoming requests
+// from initial submission to a terminal outcome (dispatch, rejection, or eviction).
//
-// # Architecture: The Processor-Shard Relationship
+// # Architecture: Supervisor-Worker Pattern
//
-// The `FlowController` engine is designed around a clear separation of state and execution. This "control plane vs.
-// data plane" separation is key to enabling dynamic, concurrent-safe configuration updates.
+// This package implements a supervisor-worker pattern to achieve high throughput and dynamic scalability.
//
-// - The `contracts.FlowRegistry` is the **control plane**. It is the single source of truth for all configuration.
-// When an administrative action occurs (e.g., `RegisterOrUpdateFlow`), the `contracts.FlowRegistry` is responsible
-// for safely applying that change to each of its managed `contracts.RegistryShard` instances.
+// - The FlowController (Supervisor): The public-facing API of the system. Its primary responsibilities are to execute
+// a distribution algorithm to select the optimal worker for a new request and to manage the lifecycle of the worker
+// pool, ensuring it stays synchronized with the underlying shard topology defined by the contracts.FlowRegistry.
+// - The internal.ShardProcessor (Worker): A stateful, single-goroutine actor responsible for the entire lifecycle of
+// requests on a single shard. The supervisor manages a pool of these workers, one for each contracts.RegistryShard.
//
-// - The `contracts.RegistryShard` is the **concurrent-safe state port**. It defines the contract for a state store
-// that holds the `contracts.ManagedQueue` and framework `Policy` instances for a single shard.
+// # Concurrency Model
//
-// - The `internal.ShardProcessor` is the **data plane worker**. It is given a single `contracts.RegistryShard` to
-// operate on. Its main `dispatchCycle` continuously acquires a read lock on the shard to get a consistent view of
-// the active queues and policies, and then executes its dispatch logic.
+// The FlowController is designed to be highly concurrent and thread-safe. It acts primarily as a stateless distributor.
//
-// This separation is what enables dynamic updates. The `internal.ShardProcessor` is stateless; it simply executes
-// against the state presented by its `contracts.RegistryShard` on each cycle. This allows the control plane
-// (`contracts.FlowRegistry`) to safely change that state in the background.
+// - EnqueueAndWait: Can be called concurrently by many goroutines.
+// - Worker Management: Uses a sync.Map (workers) for concurrent access and lazy initialization of workers.
+// - Supervision: A single background goroutine (run) manages the worker pool lifecycle (garbage collection).
//
-// # Architectural Deep Dive: The `EnqueueAndWait` Model
+// It achieves high throughput by minimizing shared state and relying on the internal ShardProcessors to handle state
+// mutations serially (using an actor model).
//
-// A fundamental design choice is the synchronous, blocking `EnqueueAndWait` method. In the context of the Gateway API
-// Inference Extension's Endpoint Picker (EPP), which operates as an Envoy External Processing (`ext_proc`) server, this
-// model is deliberately chosen for its simplicity and robustness.
+// # Request Lifecycle and Ownership
//
-// - Alignment with `ext_proc`: The `ext_proc` protocol is stream-based. A single goroutine within the EPP manages the
-// stream for a given HTTP request. `EnqueueAndWait` fits this perfectly: the request-handling goroutine calls it,
-// blocks, and upon return, has the definitive outcome. It can then immediately act on that outcome, maintaining
-// clear request-goroutine affinity.
+// A request (represented internally as a FlowItem) has a lifecycle managed cooperatively by the Controller and a
+// Processor. Defining ownership is critical for ensuring an item is finalized exactly once.
//
-// - Simplified State Management: The state of a "waiting" request is implicitly managed by the blocked goroutine's
-// stack and its `context.Context`. The `FlowController` only needs to signal this specific goroutine to unblock it.
+// 1. Submission (Controller): The Controller attempts to hand off the item to a Processor.
+// 2. Handoff:
+// - Success: Ownership transfers to the Processor, which is now responsible for Finalization.
+// - Failure: Ownership remains with the Controller, which must Finalize the item.
+// 3. Processing (Processor): The Processor enqueues, manages, and eventually dispatches or rejects the item.
+// 4. Finalization: The terminal outcome is set. This can happen:
+// - Synchronously: The Processor determines the outcome (e.g., Dispatch, Capacity Rejection).
+// - Asynchronously: The Controller observes the request's Context expiry (TTL/Cancellation) and calls Finalize.
//
-// - Direct Backpressure: If queues are full, `EnqueueAndWait` returns `types.ErrQueueAtCapacity`. This provides
-// immediate, direct backpressure to the earliest point of contact.
-//
-// # Architectural Deep Dive: The Sharded Model & JSQ-Bytes
-//
-// The `FlowController` is built on a sharded architecture to enable parallel processing and prevent a central dispatch
-// loop from becoming a bottleneck. The `FlowController` consists of a top-level manager and a pool of independent
-// `internal.ShardProcessor` workers. The `contracts.FlowRegistry` guarantees that every logical flow is represented by
-// a distinct queue instance on every active shard.
-//
-// This architecture trades deterministic global state for high throughput and scalability. The key challenge, and the
-// system's most critical assumption, revolves around ensuring this distributed model can still achieve global fairness
-// objectives.
-//
-// ## The Critical Assumption: Homogeneity Within Flows
-//
-// The effectiveness of the sharded model hinges on a critical assumption: while the system as a whole manages a
-// heterogeneous set of flows, the traffic *within a single logical flow* is assumed to be roughly homogeneous in its
-// characteristics. A logical flow is intended to represent a single workload or tenant; therefore, the most
-// unpredictable variables (effecting decode behavior) are expected to be statistically similar *within* that flow.
-//
-// ## The Hedge: Join the Shortest Queue by Bytes (JSQ-Bytes)
-//
-// To make this assumption as robust as possible, the `FlowController` uses a "Join the Shortest Queue by Bytes
-// (JSQ-Bytes)" algorithm. `ByteSize` is an excellent proxy for the resources the `FlowController` explicitly manages
-// (host memory pressure and queuing capacity) and is also a reasonable proxy for prefill compute time.
-//
-// Crucially, the goal of the distributor is not to perfectly predict backend compute time, but to intelligently balance
-// the load at the controller level. JSQ-Bytes achieves this by:
-//
-// 1. Reflecting True Load: It distributes work based on each shard's current queue size in bytes—a direct measure of
-// its memory and capacity congestion.
-//
-// 2. Adapting to Congestion: The byte-size of a queue is a real-time signal of a shard's overall congestion. If a
-// shard is slow (e.g., due to long-decoding downstream requests), its queues will remain full, and JSQ-Bytes will
-// adaptively steer new work away.
-//
-// 3. Hedging Against Assumption Violations: This adaptive, self-correcting nature makes it a powerful hedge. It
-// doesn't just distribute; it actively *load balances* based on the most relevant feedback available.
-//
-// # Architectural Deep Dive: Pre-Policy Gating
-//
-// Before policies are invoked, the `internal.ShardProcessor` applies an `internal.BandFilter`. This function determines
-// which flows within a priority band are eligible for a given operation (e.g., dispatch). This pattern is a deliberate
-// architectural choice to decouple the logic of *viability* from the logic of *selection*.
-//
-// - An `internal.BandFilter` (e.g., the `internal.NewSaturationFilter`) determines if a flow is viable based on
-// external signals like backend load.
-// - The `framework.InterFlowDispatchPolicy` then selects from among the viable candidates based on its own fairness
-// logic.
-//
-// This abstraction provides two major benefits:
-//
-// 1. Low Contributor Burden: It makes the mental model for policy contributors significantly simpler. An author of a
-// new fairness policy does not need to be concerned with the complexities of saturation detection or other gating
-// concerns. They are given a simple, pre-filtered view of the world and can focus solely on their selection logic.
-//
-// 2. Correctness by Construction: The `internal.subsetPriorityBandAccessor` wrapper guarantees that a policy operates
-// on a consistent, filtered view, regardless of which accessor method it calls (`FlowIDs`, `Queue`, etc.). This
-// prevents an entire class of subtle bugs where a policy might otherwise see a stale or unfiltered view of the
-// system state.
+// The FlowItem uses atomic operations to safely coordinate the Finalization state across goroutines.
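+//
+// # Usage Sketch
+//
+// A minimal, illustrative sketch of the synchronous calling pattern described above. The request-handling
+// goroutine blocks in EnqueueAndWait and receives the terminal outcome directly; the surrounding handler and
+// error mapping are assumptions for illustration, not part of this package's API:
+//
+//	outcome, err := fc.EnqueueAndWait(ctx, req)
+//	if err != nil {
+//		// The request was rejected or evicted; map the error to an upstream response code.
+//		return outcome, err
+//	}
+//	// outcome == types.QueueOutcomeDispatched: proceed with the admitted request.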
package controller
diff --git a/pkg/epp/flowcontrol/controller/internal/doc.go b/pkg/epp/flowcontrol/controller/internal/doc.go
index 3f39b5791..0599d5387 100644
--- a/pkg/epp/flowcontrol/controller/internal/doc.go
+++ b/pkg/epp/flowcontrol/controller/internal/doc.go
@@ -14,34 +14,23 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
-// Package internal provides the core worker implementation for the `controller.FlowController`.
+// Package internal provides the core worker implementation for the controller.FlowController.
//
-// The components in this package are the private, internal building blocks of the `controller` package. This separation
-// enforces a clean public API at the `controller` level and allows the internal mechanics of the engine to evolve
-// independently.
+// The components in this package are the private, internal building blocks of the controller. This separation enforces
+// a clean public API at the controller level and allows the internal mechanics of the engine to evolve independently.
//
-// # Design Philosophy: A Single-Writer Actor Model
+// # Design Philosophy: The Single-Writer Actor Model
//
-// The concurrency model for this package is deliberately built around a single-writer, channel-based actor pattern, as
-// implemented in the `ShardProcessor`. While a simple lock-based approach might seem easier, it is insufficient for the
-// system's requirements. The "enqueue" operation is a complex, stateful transaction that requires a **hierarchical
-// capacity check** against both the overall shard and a request's specific priority band.
+// The concurrency model for this package is built around a single-writer, channel-based actor pattern, as implemented
+// in the ShardProcessor. All state-mutating operations for a given shard (primarily enqueuing new requests) are
+// funneled through a single Run goroutine.
//
-// A coarse, shard-wide lock would be required to make this transaction atomic, creating a major performance bottleneck
-// and causing head-of-line blocking at the top-level `controller.FlowController`. The single-writer model, where all
-// state mutations are funneled through a single goroutine, makes this transaction atomic *without locks*.
+// This design makes complex, multi-step transactions (like a hierarchical capacity check against both a shard's total
+// limit and a priority band's limit) inherently atomic without locks. This avoids the performance bottleneck of a
+// coarse, shard-wide lock and allows the top-level Controller to remain decoupled and highly concurrent.
//
-// This design provides two critical benefits:
-// 1. **Decoupling:** The `controller.FlowController` is decoupled via a non-blocking channel send, allowing for much
-// higher throughput.
-// 2. **Backpressure:** The state of the channel buffer serves as a high-fidelity, real-time backpressure signal,
-// enabling more intelligent load balancing.
+// # Key Components
//
-// # Future-Proofing for Complex Transactions
-//
-// This model's true power is that it provides a robust foundation for future features like **displacement** (a
-// high-priority item evicting lower-priority ones). This is an "all-or-nothing" atomic transaction that is
-// exceptionally difficult to implement correctly in a lock-free or coarse-grained locking model without significant
-// performance penalties. The single-writer model contains the performance cost of such a potentially long transaction
-// to the single `ShardProcessor`, preventing it from blocking the entire `controller.FlowController`.
+// - ShardProcessor: The implementation of the worker actor. Manages the lifecycle of requests for a single shard.
+// - FlowItem: The internal representation of a request, managing its state and synchronization across goroutines.
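+//
+// At its core, the single-writer pattern reduces to a loop of the following shape (an illustrative sketch
+// only; the real loop in ShardProcessor.Run also handles shutdown draining and dispatch cycles, and the
+// helper names shown here are placeholders):
+//
+//	for {
+//		select {
+//		case item := <-enqueueChan:
+//			enqueue(item) // Sole writer: capacity checks and queue adds are serialized here.
+//		case <-cleanupTicker.C:
+//			sweepExpiredItems()
+//		case <-ctx.Done():
+//			return
+//		}
+//	}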
package internal
diff --git a/pkg/epp/flowcontrol/controller/internal/filter.go b/pkg/epp/flowcontrol/controller/internal/filter.go
deleted file mode 100644
index 0a0669224..000000000
--- a/pkg/epp/flowcontrol/controller/internal/filter.go
+++ /dev/null
@@ -1,144 +0,0 @@
-/*
-Copyright 2025 The Kubernetes Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-
-package internal
-
-import (
- "context"
-
- "github.com/go-logr/logr"
-
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types"
- logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
-)
-
-// BandFilter is a function that acts as a pre-policy gate. It takes a complete view of a priority band and returns a
-// potentially filtered `framework.PriorityBandAccessor` containing only the flows that are viable candidates for a
-// subsequent policy decision. It can also return a boolean signal to pause the entire operation for the band.
-//
-// This abstraction decouples the logic of determining *viability* (e.g., is a flow subject to backpressure?) from the
-// logic of *selection* (e.g., which of the viable flows is the fairest to pick next?). This separation simplifies the
-// mental model for policy authors, who can focus solely on selection logic without needing to account for external
-// gating signals.
-//
-// Because filters are applied before the band is passed to a policy, they are inherently composable. Multiple filters
-// can be chained to apply different viability criteria. For example, a future filter could be developed to temporarily
-// exclude a "misbehaving" flow that is causing repeated errors, quarantining it from policy consideration.
-//
-// A nil returned `PriorityBandAccessor` indicates that no filtering was necessary and the original accessor should be
-// used. This provides a zero-allocation fast path for the common case where no flows are being filtered.
-type BandFilter func(
- ctx context.Context,
- band framework.PriorityBandAccessor,
- logger logr.Logger,
-) (filteredBand framework.PriorityBandAccessor, shouldPause bool)
-
-// NewSaturationFilter creates a `BandFilter` that uses the provided `contracts.SaturationDetector` to determine which
-// flows are dispatchable. This is the standard filter used in the production `FlowController` for the dispatch
-// operation.
-func NewSaturationFilter(sd contracts.SaturationDetector) BandFilter {
- return func(
- ctx context.Context,
- band framework.PriorityBandAccessor,
- logger logr.Logger,
- ) (framework.PriorityBandAccessor, bool) {
- // Phase 1: Implement the current global saturation check.
- if sd.IsSaturated(ctx) {
- logger.V(logutil.VERBOSE).Info("System saturated, pausing dispatch for this shard.")
- return nil, true // Pause dispatching for all bands.
- }
-
- // Phase 2 (Future): This is where per-flow saturation logic would go.
- // It would iterate `band`, call `IsSaturated(ctx, flowID)`, and build a filtered map of allowed flows,
- // then return `newSubsetPriorityBandAccessor(band, allowedFlows)`.
- // For now, no per-flow filtering is done. Return nil to signal the fast path.
- return nil, false // Do not pause, and do not filter any flows.
- }
-}
-
-// subsetPriorityBandAccessor provides a view of a priority band that is restricted to a specific subset of flows.
-// It implements the `framework.PriorityBandAccessor` interface, ensuring that any policy operating on it will only
-// see the allowed flows, regardless of which accessor method is used. This provides correctness by construction.
-//
-// For performance, it pre-computes a slice of the allowed flows at creation time, making subsequent calls to
-// `FlowKeys()` an O(1) operation with zero allocations.
-type subsetPriorityBandAccessor struct {
- originalAccessor framework.PriorityBandAccessor
- allowedFlows map[types.FlowKey]struct{}
- allowedFlowsSlice []types.FlowKey
-}
-
-var _ framework.PriorityBandAccessor = &subsetPriorityBandAccessor{}
-
-// newSubsetPriorityBandAccessor creates a new filtered view of a priority band.
-func newSubsetPriorityBandAccessor(original framework.PriorityBandAccessor, allowed []types.FlowKey) *subsetPriorityBandAccessor {
- // Pre-compute the map for efficient lookups in `Queue()` and `IterateQueues()`.
- allowedMap := make(map[types.FlowKey]struct{}, len(allowed))
- for _, k := range allowed {
- allowedMap[k] = struct{}{}
- }
-
- return &subsetPriorityBandAccessor{
- originalAccessor: original,
- allowedFlows: allowedMap,
- allowedFlowsSlice: allowed,
- }
-}
-
-// Priority returns the numerical priority level of this band.
-func (s *subsetPriorityBandAccessor) Priority() uint {
- return s.originalAccessor.Priority()
-}
-
-// PriorityName returns the human-readable name of this priority band.
-func (s *subsetPriorityBandAccessor) PriorityName() string {
- return s.originalAccessor.PriorityName()
-}
-
-// FlowKeys returns a slice of the composite `types.FlowKey`s for every flow instance currently active within this
-// priority band that are in the allowed subset.
-// This is an O(1) operation because the slice is pre-computed at creation.
-func (s *subsetPriorityBandAccessor) FlowKeys() []types.FlowKey {
- return s.allowedFlowsSlice
-}
-
-// Queue returns a `framework.FlowQueueAccessor` for the specified `ID` within this priority band, but only if it is
-// in the allowed subset. This is an O(1) map lookup. If the flow is not in the allowed subset, it returns nil.
-func (s *subsetPriorityBandAccessor) Queue(id string) framework.FlowQueueAccessor {
- key := types.FlowKey{ID: id, Priority: s.Priority()}
- if _, ok := s.allowedFlows[key]; !ok {
- return nil
- }
- return s.originalAccessor.Queue(id)
-}
-
-// IterateQueues executes the given `callback` for each `framework.FlowQueueAccessor` in the allowed subset of this
-// priority band. The iteration stops if the callback returns false.
-// This implementation delegates to the original accessor's iterator and applies the filter, which is more robust and
-// efficient than iterating over a pre-computed slice of IDs.
-func (s *subsetPriorityBandAccessor) IterateQueues(callback func(queue framework.FlowQueueAccessor) bool) {
- s.originalAccessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool {
- if _, ok := s.allowedFlows[queue.FlowKey()]; ok {
- // This queue is in the allowed set, so execute the callback.
- if !callback(queue) {
- return false // The callback requested to stop, so we stop the outer iteration too.
- }
- }
- return true // Continue iterating over the original set.
- })
-}
diff --git a/pkg/epp/flowcontrol/controller/internal/filter_test.go b/pkg/epp/flowcontrol/controller/internal/filter_test.go
deleted file mode 100644
index ceff9e83f..000000000
--- a/pkg/epp/flowcontrol/controller/internal/filter_test.go
+++ /dev/null
@@ -1,174 +0,0 @@
-/*
-Copyright 2025 The Kubernetes Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-
-package internal
-
-import (
- "context"
- "sort"
- "testing"
-
- "github.com/go-logr/logr"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts/mocks"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework"
- frameworkmocks "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/mocks"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types"
-)
-
-func TestNewSaturationFilter(t *testing.T) {
- t.Parallel()
-
- testCases := []struct {
- name string
- isSaturated bool
- expectShouldPause bool
- expectFilteredBandNil bool
- }{
- {
- name: "should not pause or filter when system is not saturated",
- isSaturated: false,
- expectShouldPause: false,
- expectFilteredBandNil: true, // nil band signals the fast path
- },
- {
- name: "should pause when system is saturated",
- isSaturated: true,
- expectShouldPause: true,
- expectFilteredBandNil: true,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- t.Parallel()
-
- // --- ARRANGE ---
- mockSD := &mocks.MockSaturationDetector{IsSaturatedFunc: func(ctx context.Context) bool { return tc.isSaturated }}
- filter := NewSaturationFilter(mockSD)
- require.NotNil(t, filter, "NewSaturationFilter should not return nil")
-
- mockBand := &frameworkmocks.MockPriorityBandAccessor{}
-
- // --- ACT ---
- filteredBand, shouldPause := filter(context.Background(), mockBand, logr.Discard())
-
- // --- ASSERT ---
- assert.Equal(t, tc.expectShouldPause, shouldPause, "The filter's pause signal should match the expected value")
-
- if tc.expectFilteredBandNil {
- assert.Nil(t, filteredBand, "Expected filtered band to be nil")
- } else {
- assert.NotNil(t, filteredBand, "Expected a non-nil filtered band")
- }
- })
- }
-}
-
-func TestSubsetPriorityBandAccessor(t *testing.T) {
- t.Parallel()
-
- // --- ARRANGE ---
- // Setup a mock original accessor that knows about three flows.
- flowAKey := types.FlowKey{ID: "flow-a", Priority: 10}
- flowBKey := types.FlowKey{ID: "flow-b", Priority: 10}
- flowCKey := types.FlowKey{ID: "flow-c", Priority: 10}
-
- mockQueueA := &frameworkmocks.MockFlowQueueAccessor{FlowKeyV: flowAKey}
- mockQueueB := &frameworkmocks.MockFlowQueueAccessor{FlowKeyV: flowBKey}
- mockQueueC := &frameworkmocks.MockFlowQueueAccessor{FlowKeyV: flowCKey}
-
- originalAccessor := &frameworkmocks.MockPriorityBandAccessor{
- PriorityV: 10,
- PriorityNameV: "High",
- FlowKeysFunc: func() []types.FlowKey {
- return []types.FlowKey{flowAKey, flowBKey, flowCKey}
- },
- QueueFunc: func(id string) framework.FlowQueueAccessor {
- switch id {
- case "flow-a":
- return mockQueueA
- case "flow-b":
- return mockQueueB
- case "flow-c":
- return mockQueueC
- }
- return nil
- },
- IterateQueuesFunc: func(callback func(queue framework.FlowQueueAccessor) bool) {
- if !callback(mockQueueA) {
- return
- }
- if !callback(mockQueueB) {
- return
- }
- callback(mockQueueC)
- },
- }
-
- // Create a subset view that only allows two of the flows.
- allowedFlows := []types.FlowKey{flowAKey, flowCKey}
- subsetAccessor := newSubsetPriorityBandAccessor(originalAccessor, allowedFlows)
- require.NotNil(t, subsetAccessor, "newSubsetPriorityBandAccessor should not return nil")
-
- t.Run("should pass through priority and name", func(t *testing.T) {
- t.Parallel()
- assert.Equal(t, uint(10), subsetAccessor.Priority(), "Priority() should pass through from the original accessor")
- assert.Equal(t, "High", subsetAccessor.PriorityName(),
- "PriorityName() should pass through from the original accessor")
- })
-
- t.Run("should only return allowed flow keys", func(t *testing.T) {
- t.Parallel()
- flowKeys := subsetAccessor.FlowKeys()
- // Sort for consistent comparison, as the pre-computed slice order is not guaranteed.
- sort.Slice(flowKeys, func(i, j int) bool {
- return flowKeys[i].ID < flowKeys[j].ID
- })
- assert.Equal(t, []types.FlowKey{flowAKey, flowCKey}, flowKeys, "FlowKeys() should only return the allowed subset")
- })
-
- t.Run("should only return queues for allowed flows", func(t *testing.T) {
- t.Parallel()
- assert.Same(t, mockQueueA, subsetAccessor.Queue("flow-a"), "Should return queue for allowed flow 'a'")
- assert.Nil(t, subsetAccessor.Queue("flow-b"), "Should not return queue for disallowed flow 'b'")
- assert.Same(t, mockQueueC, subsetAccessor.Queue("flow-c"), "Should return queue for allowed flow 'c'")
- })
-
- t.Run("should only iterate over allowed queues", func(t *testing.T) {
- t.Parallel()
- var iterated []string
- subsetAccessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool {
- iterated = append(iterated, queue.FlowKey().ID)
- return true
- })
- // Sort for consistent comparison, as iteration order is not guaranteed.
- sort.Strings(iterated)
- assert.Equal(t, []string{"flow-a", "flow-c"}, iterated, "IterateQueues() should only visit allowed flows")
- })
-
- t.Run("should stop iteration if callback returns false", func(t *testing.T) {
- t.Parallel()
- var iterated []string
- subsetAccessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool {
- iterated = append(iterated, queue.FlowKey().ID)
- return false // Exit after the first item.
- })
- assert.Len(t, iterated, 1, "Iteration should have stopped after one item")
- })
-}
diff --git a/pkg/epp/flowcontrol/controller/internal/item.go b/pkg/epp/flowcontrol/controller/internal/item.go
index 86aeb8a0c..f0d5d3286 100644
--- a/pkg/epp/flowcontrol/controller/internal/item.go
+++ b/pkg/epp/flowcontrol/controller/internal/item.go
@@ -17,141 +17,177 @@ limitations under the License.
package internal
import (
+ "context"
+ "errors"
+ "fmt"
+ "strconv"
"sync"
"sync/atomic"
"time"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
)
-// flowItem is the internal representation of a request managed by the `FlowController`. It implements the
-// `types.QueueItemAccessor` interface, which is the primary view of the item used by queue and policy implementations.
-// It wraps the original `types.FlowControlRequest` and adds metadata for queuing, lifecycle management, and policy
-// interaction.
+// FinalState encapsulates the terminal outcome of a FlowItem's lifecycle.
+type FinalState struct {
+ Outcome types.QueueOutcome
+ Err error
+}
+
+// FlowItem is the internal representation of a request managed by the FlowController.
+//
+// # Lifecycle Management
+//
+// Finalization (determining the terminal outcome) can be initiated by the Controller (e.g., on Context expiry) or
+// the Processor (e.g., on Dispatch/Reject); it records the outcome and signals the waiting goroutine.
//
-// # Concurrency
+// # Synchronization
//
-// The `finalize` method is the primary point of concurrency concern. It is designed to be atomic and idempotent through
-// the use of `sync.Once`. This guarantees that an item's final outcome can be set exactly once, even if multiple
-// goroutines (e.g., the main dispatch loop and the expiry cleanup loop) race to finalize it. All other fields are set
-// at creation time and are not modified thereafter, making them safe for concurrent access.
-type flowItem struct {
- // enqueueTime is the timestamp when the item was logically accepted by the `FlowController`.
- enqueueTime time.Time
- // effectiveTTL is the actual time-to-live assigned to this item.
- effectiveTTL time.Duration
- // originalRequest is the underlying request object.
+// Atomic operations synchronize state across the Controller and Processor goroutines:
+// - finalState (atomic.Pointer): Safely publishes the outcome.
+// - handle (atomic.Pointer): Safely publishes the queue admission status.
+type FlowItem struct {
+ // --- Immutable fields during a single lifecycle ---
+
+ enqueueTime time.Time
+ effectiveTTL time.Duration
originalRequest types.FlowControlRequest
- // handle is the unique identifier for this item within a specific queue instance.
- handle types.QueueItemHandle
-
- // done is closed exactly once when the item is finalized (dispatched or evicted/rejected).
- done chan struct{}
- // err stores the final error state if the item was not successfully dispatched.
- // It is written to exactly once, protected by `onceFinalize`.
- err atomic.Value // Stores error
- // outcome stores the final `types.QueueOutcome` of the item's lifecycle.
- // It is written to exactly once, protected by `onceFinalize`.
- outcome atomic.Value // Stores `types.QueueOutcome`
- // onceFinalize ensures the `finalize()` logic is idempotent.
+
+ // --- Synchronized State ---
+
+ // handle stores the types.QueueItemHandle atomically.
+ // Written by the Processor (SetHandle) when admitted.
+ // Read by inferOutcome (called by Finalize) to infer the outcome (Rejected vs. Evicted).
+ // Distinguishing between pre-admission (Rejection) and post-admission (Eviction) during asynchronous finalization
+ // relies on whether this handle is nil or non-nil.
+ handle atomic.Pointer[types.QueueItemHandle]
+
+ // finalState holds the result of the finalization. Stored atomically once.
+ // Use FinalState() for safe access.
+ finalState atomic.Pointer[FinalState]
+
+ // --- Finalization Signaling ---
+
+ // done is the channel used to signal the completion of the item's lifecycle.
+ // Buffered to size 1 to prevent Finalize from blocking.
+ done chan *FinalState
+
+ // onceFinalize ensures the finalization logic runs exactly once per lifecycle.
onceFinalize sync.Once
}
-// ensure flowItem implements the interface.
-var _ types.QueueItemAccessor = &flowItem{}
+var _ types.QueueItemAccessor = &FlowItem{}
-// NewItem creates a new `flowItem`, which is the internal representation of a request inside the `FlowController`.
-// This constructor is exported so that the parent `controller` package can create items to be passed into the
-// `internal` package's processors. It initializes the item with a "NotYetFinalized" outcome and an open `done` channel.
-func NewItem(req types.FlowControlRequest, effectiveTTL time.Duration, enqueueTime time.Time) *flowItem {
- fi := &flowItem{
+// NewItem allocates and initializes a new FlowItem for a request lifecycle.
+func NewItem(req types.FlowControlRequest, effectiveTTL time.Duration, enqueueTime time.Time) *FlowItem {
+ return &FlowItem{
enqueueTime: enqueueTime,
effectiveTTL: effectiveTTL,
originalRequest: req,
- done: make(chan struct{}),
+ done: make(chan *FinalState, 1),
}
- // Initialize the outcome to its zero state.
- fi.outcome.Store(types.QueueOutcomeNotYetFinalized)
- return fi
}
-// EnqueueTime returns the time the item was logically accepted by the `FlowController` for queuing. This is used as the
-// basis for TTL calculations.
-func (fi *flowItem) EnqueueTime() time.Time { return fi.enqueueTime }
+// EnqueueTime returns the time the item was logically accepted by the FlowController.
+func (fi *FlowItem) EnqueueTime() time.Time { return fi.enqueueTime }
-// EffectiveTTL returns the actual time-to-live assigned to this item by the `FlowController`.
-func (fi *flowItem) EffectiveTTL() time.Duration { return fi.effectiveTTL }
+// EffectiveTTL returns the actual time-to-live assigned to this item.
+func (fi *FlowItem) EffectiveTTL() time.Duration { return fi.effectiveTTL }
-// OriginalRequest returns the original, underlying `types.FlowControlRequest` object.
-func (fi *flowItem) OriginalRequest() types.FlowControlRequest { return fi.originalRequest }
+// OriginalRequest returns the original types.FlowControlRequest object.
+func (fi *FlowItem) OriginalRequest() types.FlowControlRequest { return fi.originalRequest }
-// Handle returns the `types.QueueItemHandle` that uniquely identifies this item within a specific queue instance. It
-// returns nil if the item has not yet been added to a queue.
-func (fi *flowItem) Handle() types.QueueItemHandle { return fi.handle }
+// Done returns a read-only channel that will receive the FinalState pointer exactly once.
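+//
+// A typical waiter selects on Done alongside its own context (an illustrative sketch; the Controller's real
+// logic differs in its details):
+//
+//	select {
+//	case fs := <-item.Done():
+//		return fs.Outcome, fs.Err
+//	case <-ctx.Done():
+//		item.Finalize(ctx.Err()) // Idempotent; the first finalizer wins.
+//		fs := <-item.Done()      // The buffered channel still yields the authoritative state.
+//		return fs.Outcome, fs.Err
+//	}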
+func (fi *FlowItem) Done() <-chan *FinalState { return fi.done }
-// SetHandle associates a `types.QueueItemHandle` with this item. This method is called by a `framework.SafeQueue`
-// implementation immediately after the item is added to the queue.
-func (fi *flowItem) SetHandle(handle types.QueueItemHandle) { fi.handle = handle }
+// FinalState returns the FinalState if the item has been finalized, or nil otherwise.
+// Safe for concurrent access.
+func (fi *FlowItem) FinalState() *FinalState { return fi.finalState.Load() }
-// Done returns a channel that is closed when the item has been finalized (e.g., dispatched or evicted).
-// This is the primary mechanism for consumers to wait for an item's outcome. It is designed to be used in a `select`
-// statement, allowing the caller to simultaneously wait for other events, such as context cancellation.
-//
-// # Example Usage
-//
-// select {
-// case <-item.Done():
-// outcome, err := item.FinalState()
-// // ... handle outcome
-// case <-ctx.Done():
-// // ... handle cancellation
-// }
-func (fi *flowItem) Done() <-chan struct{} {
- return fi.done
+// Handle returns the types.QueueItemHandle for this item within a queue.
+// Returns nil if the item is not in a queue. Safe for concurrent access.
+func (fi *FlowItem) Handle() types.QueueItemHandle {
+ ptr := fi.handle.Load()
+ if ptr == nil {
+ return nil
+ }
+ return *ptr
}
-// FinalState returns the terminal outcome and error for the item.
-//
-// CRITICAL: This method must only be called after the channel returned by `Done()` has been closed. Calling it before
-// the item is finalized may result in a race condition where the final state has not yet been written.
-func (fi *flowItem) FinalState() (types.QueueOutcome, error) {
- outcomeVal := fi.outcome.Load()
- errVal := fi.err.Load()
-
- var finalOutcome types.QueueOutcome
- if oc, ok := outcomeVal.(types.QueueOutcome); ok {
- finalOutcome = oc
- } else {
- // This case should not happen if finalize is always called correctly, but we default to a safe value.
- finalOutcome = types.QueueOutcomeNotYetFinalized
- }
+// SetHandle associates a types.QueueItemHandle with this item. Called by the queue implementation (via Processor).
+// Safe for concurrent access.
+func (fi *FlowItem) SetHandle(handle types.QueueItemHandle) { fi.handle.Store(&handle) }
- var finalErr error
- if e, ok := errVal.(error); ok {
- finalErr = e
- }
- return finalOutcome, finalErr
+// Finalize determines the item's terminal state based on the provided cause (e.g., Context error) and the item's
+// current admission status (queued or not).
+//
+// This method is intended for asynchronous finalization initiated by the Controller (e.g., TTL expiry).
+// It is idempotent.
+func (fi *FlowItem) Finalize(cause error) {
+ fi.onceFinalize.Do(func() {
+ // Atomically load the handle to determine if the item was admitted to a queue.
+ // This synchronization is critical for correctly inferring the outcome across goroutines.
+ isQueued := fi.Handle() != nil
+ outcome, finalErr := inferOutcome(cause, isQueued)
+ fi.finalizeInternal(outcome, finalErr)
+ })
}
-// finalize sets the item's terminal state (`outcome`, `error`) and closes its `done` channel idempotently using
-// `sync.Once`. This is the single, internal point where an item's lifecycle within the `FlowController` concludes.
-func (fi *flowItem) finalize(outcome types.QueueOutcome, err error) {
+// FinalizeWithOutcome sets the item's terminal state explicitly.
+//
+// This method is intended for synchronous finalization by the Processor (Dispatch, Reject) or the Controller
+// (Distribution failure).
+// It is idempotent.
+func (fi *FlowItem) FinalizeWithOutcome(outcome types.QueueOutcome, err error) {
fi.onceFinalize.Do(func() {
- if err != nil {
- fi.err.Store(err)
- }
- fi.outcome.Store(outcome)
- close(fi.done)
+ fi.finalizeInternal(outcome, err)
})
}
-// isFinalized checks if the item has been finalized without blocking. It is used internally by the `ShardProcessor` as
-// a defensive check to avoid operating on items that have already been completed.
-func (fi *flowItem) isFinalized() bool {
- select {
- case <-fi.done:
- return true
+// finalizeInternal is the core finalization logic. It must be called within the sync.Once.Do block.
+// It captures the state, stores it atomically, and signals the Done channel.
+func (fi *FlowItem) finalizeInternal(outcome types.QueueOutcome, err error) {
+ finalState := &FinalState{
+ Outcome: outcome,
+ Err: err,
+ }
+
+ // Atomically store the pointer. This is the critical memory barrier that publishes the state safely.
+ fi.finalState.Store(finalState)
+
+ duration := time.Since(fi.enqueueTime)
+ flowKey := fi.originalRequest.FlowKey()
+ metrics.RecordFlowControlRequestQueueDuration(flowKey.ID, strconv.Itoa(flowKey.Priority), outcome.String(), duration)
+
+ fi.done <- finalState
+ close(fi.done)
+}
+
+// inferOutcome determines the correct QueueOutcome and Error based on the cause of finalization and whether the item
+// was already admitted to a queue.
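+//
+// For example (derived from the cases below): a context.DeadlineExceeded cause on an item that already holds a
+// queue handle yields (QueueOutcomeEvictedTTL, ErrEvicted wrapping ErrTTLExpired), while the same cause on an
+// item that was never admitted yields (QueueOutcomeRejectedOther, ErrRejected wrapping ErrTTLExpired).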
+func inferOutcome(cause error, isQueued bool) (types.QueueOutcome, error) {
+ var specificErr error
+ var outcomeIfEvicted types.QueueOutcome
+ switch {
+ case errors.Is(cause, types.ErrTTLExpired) || errors.Is(cause, context.DeadlineExceeded):
+ specificErr = types.ErrTTLExpired
+ outcomeIfEvicted = types.QueueOutcomeEvictedTTL
+ case errors.Is(cause, context.Canceled):
+ specificErr = fmt.Errorf("%w: %w", types.ErrContextCancelled, cause)
+ outcomeIfEvicted = types.QueueOutcomeEvictedContextCancelled
default:
- return false
+ // Handle other potential causes (e.g., custom context errors).
+ specificErr = cause
+ outcomeIfEvicted = types.QueueOutcomeEvictedOther
}
+
+ if isQueued {
+ // The item was in the queue when it expired/cancelled.
+ return outcomeIfEvicted, fmt.Errorf("%w: %w", types.ErrEvicted, specificErr)
+ }
+
+ // The item was not yet in the queue (e.g., buffered in enqueueChan).
+ // We treat this as a rejection, as it never formally consumed queue capacity.
+ return types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, specificErr)
}
diff --git a/pkg/epp/flowcontrol/controller/internal/item_test.go b/pkg/epp/flowcontrol/controller/internal/item_test.go
index d50aaed41..9b7b627c2 100644
--- a/pkg/epp/flowcontrol/controller/internal/item_test.go
+++ b/pkg/epp/flowcontrol/controller/internal/item_test.go
@@ -18,6 +18,7 @@ package internal
import (
"context"
+ "errors"
"testing"
"time"
@@ -28,25 +29,208 @@ import (
typesmocks "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types/mocks"
)
-func TestItem(t *testing.T) {
+func TestFlowItem_New(t *testing.T) {
t.Parallel()
+ req := typesmocks.NewMockFlowControlRequest(100, "req-1", types.FlowKey{})
- t.Run("should correctly set and get handle", func(t *testing.T) {
- t.Parallel()
- item := &flowItem{}
- handle := &typesmocks.MockQueueItemHandle{}
- item.SetHandle(handle)
- assert.Same(t, handle, item.Handle(), "Handle() should retrieve the same handle instance set by SetHandle()")
- })
-
- t.Run("should have a non-finalized state upon creation", func(t *testing.T) {
- t.Parallel()
- key := types.FlowKey{ID: "flow-a", Priority: 10}
- req := typesmocks.NewMockFlowControlRequest(100, "req-1", key, context.Background())
- item := NewItem(req, time.Minute, time.Now())
- require.NotNil(t, item, "NewItem should not return nil")
- outcome, err := item.FinalState()
- assert.Equal(t, types.QueueOutcomeNotYetFinalized, outcome, "A new item's outcome should be NotYetFinalized")
- assert.NoError(t, err, "A new item should have a nil error")
- })
+ enqueueTime := time.Now()
+ item := NewItem(req, time.Minute, enqueueTime)
+
+ require.NotNil(t, item, "NewItem should not return a nil item")
+ assert.Equal(t, enqueueTime, item.EnqueueTime(), "EnqueueTime should be populated")
+ assert.Equal(t, time.Minute, item.EffectiveTTL(), "EffectiveTTL should be populated")
+ assert.Same(t, req, item.OriginalRequest(), "OriginalRequest should be populated")
+ assert.Nil(t, item.FinalState(), "a new item must not have a final state")
+ select {
+ case <-item.Done():
+ t.Fatal("Done() channel for a new item must block, but it was closed")
+ default:
+ // This is the expected path, as the channel would have blocked.
+ }
+}
+
+func TestFlowItem_Handle(t *testing.T) {
+ t.Parallel()
+ item := &FlowItem{}
+ handle := &typesmocks.MockQueueItemHandle{}
+ item.SetHandle(handle)
+ assert.Same(t, handle, item.Handle(), "Handle() must retrieve the identical handle instance set by SetHandle()")
+}
+
+func TestFlowItem_Finalize_Idempotency(t *testing.T) {
+ t.Parallel()
+ now := time.Now()
+ req := typesmocks.NewMockFlowControlRequest(100, "req-1", types.FlowKey{})
+
+ testCases := []struct {
+ name string
+ firstCall func(item *FlowItem)
+ secondCall func(item *FlowItem)
+ expectedOutcome types.QueueOutcome
+ expectedErrIs error
+ }{
+ {
+ name: "Finalize then Finalize",
+ firstCall: func(item *FlowItem) {
+ item.Finalize(types.ErrTTLExpired)
+ },
+ secondCall: func(item *FlowItem) {
+ item.Finalize(context.Canceled)
+ },
+ expectedOutcome: types.QueueOutcomeRejectedOther,
+ expectedErrIs: types.ErrTTLExpired,
+ },
+ {
+ name: "Finalize then FinalizeWithOutcome",
+ firstCall: func(item *FlowItem) {
+ item.Finalize(types.ErrTTLExpired)
+ },
+ secondCall: func(item *FlowItem) {
+ item.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil)
+ },
+ expectedOutcome: types.QueueOutcomeRejectedOther,
+ expectedErrIs: types.ErrTTLExpired,
+ },
+ {
+ name: "FinalizeWithOutcome then FinalizeWithOutcome",
+ firstCall: func(item *FlowItem) {
+ item.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil)
+ },
+ secondCall: func(item *FlowItem) {
+ item.FinalizeWithOutcome(types.QueueOutcomeRejectedCapacity, errors.New("rejected"))
+ },
+ expectedOutcome: types.QueueOutcomeDispatched,
+ expectedErrIs: nil,
+ },
+ {
+ name: "FinalizeWithOutcome then Finalize",
+ firstCall: func(item *FlowItem) {
+ item.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil)
+ },
+ secondCall: func(item *FlowItem) {
+ item.Finalize(types.ErrTTLExpired)
+ },
+ expectedOutcome: types.QueueOutcomeDispatched,
+ expectedErrIs: nil,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ item := NewItem(req, time.Minute, now)
+
+ // First call
+ tc.firstCall(item)
+
+ // Second call
+ tc.secondCall(item)
+
+ // Check FinalState()
+ finalState := item.FinalState()
+ require.NotNil(t, finalState, "FinalState should not be nil")
+ assert.Equal(t, tc.expectedOutcome, finalState.Outcome, "Outcome should match the first call")
+ if tc.expectedErrIs != nil {
+ assert.ErrorIs(t, finalState.Err, tc.expectedErrIs, "Error should match the first call")
+ } else {
+ assert.NoError(t, finalState.Err, "Error should be nil")
+ }
+
+ // Check Done channel
+ select {
+ case state, ok := <-item.Done():
+ require.True(t, ok, "Done channel should be readable")
+ assert.Equal(t, tc.expectedOutcome, state.Outcome, "Done channel outcome should match the first call")
+ if tc.expectedErrIs != nil {
+ assert.ErrorIs(t, state.Err, tc.expectedErrIs, "Done channel error should match the first call")
+ } else {
+ assert.NoError(t, state.Err, "Done channel error should be nil")
+ }
+ case <-time.After(50 * time.Millisecond):
+ t.Fatal("Done channel should have received the state")
+ }
+ })
+ }
+}
+
+func TestFlowItem_Finalize_InferOutcome(t *testing.T) {
+ t.Parallel()
+ now := time.Now()
+
+ testCases := []struct {
+ name string
+ cause error
+ isQueued bool
+ expectOutcome types.QueueOutcome
+ expectErrIs error
+ }{
+ {
+ name: "queued TTL expired",
+ cause: types.ErrTTLExpired,
+ isQueued: true,
+ expectOutcome: types.QueueOutcomeEvictedTTL,
+ expectErrIs: types.ErrTTLExpired,
+ },
+ {
+ name: "queued context cancelled",
+ cause: context.Canceled,
+ isQueued: true,
+ expectOutcome: types.QueueOutcomeEvictedContextCancelled,
+ expectErrIs: types.ErrContextCancelled,
+ },
+ {
+ name: "queued other error",
+ cause: errors.New("other cause"),
+ isQueued: true,
+ expectOutcome: types.QueueOutcomeEvictedOther,
+ expectErrIs: types.ErrEvicted,
+ },
+ {
+ name: "not queued TTL expired",
+ cause: types.ErrTTLExpired,
+ isQueued: false,
+ expectOutcome: types.QueueOutcomeRejectedOther,
+ expectErrIs: types.ErrTTLExpired,
+ },
+ {
+ name: "not queued context cancelled",
+ cause: context.Canceled,
+ isQueued: false,
+ expectOutcome: types.QueueOutcomeRejectedOther,
+ expectErrIs: types.ErrContextCancelled,
+ },
+ {
+ name: "nil cause queued",
+ cause: nil,
+ isQueued: true,
+ expectOutcome: types.QueueOutcomeEvictedOther,
+ expectErrIs: types.ErrEvicted,
+ },
+ {
+ name: "nil cause not queued",
+ cause: nil,
+ isQueued: false,
+ expectOutcome: types.QueueOutcomeRejectedOther,
+ expectErrIs: types.ErrRejected,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ req := typesmocks.NewMockFlowControlRequest(100, "req-1", types.FlowKey{})
+ item := NewItem(req, time.Minute, now)
+ if tc.isQueued {
+ item.SetHandle(&typesmocks.MockQueueItemHandle{})
+ }
+
+ item.Finalize(tc.cause)
+
+ finalState := item.FinalState()
+ require.NotNil(t, finalState, "FinalState should not be nil")
+ assert.Equal(t, tc.expectOutcome, finalState.Outcome, "Unexpected outcome")
+ require.Error(t, finalState.Err, "An error should be set")
+ assert.ErrorIs(t, finalState.Err, tc.expectErrIs, "Unexpected error type")
+ })
+ }
}
diff --git a/pkg/epp/flowcontrol/controller/internal/processor.go b/pkg/epp/flowcontrol/controller/internal/processor.go
index 7f9c8ee3a..2370fd646 100644
--- a/pkg/epp/flowcontrol/controller/internal/processor.go
+++ b/pkg/epp/flowcontrol/controller/internal/processor.go
@@ -26,7 +26,7 @@ import (
"time"
"github.com/go-logr/logr"
- "sigs.k8s.io/controller-runtime/pkg/log"
+ "k8s.io/utils/clock"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework"
@@ -34,124 +34,144 @@ import (
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)
-const (
- // enqueueChannelBufferSize sets the size of the buffered channel that accepts incoming requests for the shard
- // processor. This buffer acts as a "shock absorber," decoupling the upstream distributor from the processor's main
- // loop and allowing the system to handle short, intense bursts of traffic without blocking the distributor.
- enqueueChannelBufferSize = 100
+// maxCleanupWorkers caps the number of concurrent workers for background cleanup tasks. This prevents a single shard
+// from overwhelming the Go scheduler with too many goroutines.
+const maxCleanupWorkers = 4
- // maxCleanupWorkers caps the number of concurrent workers for background cleanup tasks. This prevents a single shard
- // from overwhelming the Go scheduler with too many goroutines.
- maxCleanupWorkers = 4
-)
-
-var (
- // errInterFlow is a sentinel error for failures during the inter-flow dispatch phase (e.g., a
- // `framework.InterFlowDispatchPolicy` fails to select a queue).
- //
- // Strategy: When this error is encountered, the dispatch cycle aborts processing for the current priority band and
- // immediately moves to the next, promoting work conservation. A failure in one band should not halt progress in
- // others.
- errInterFlow = errors.New("inter-flow policy failure")
-
- // errIntraFlow is a sentinel error for failures *after* a specific flow's queue has been selected (e.g., a
- // `framework.IntraFlowDispatchPolicy` fails or a queue `Remove` fails).
- //
- // Strategy: When this error is encountered, the dispatch cycle aborts processing for the entire priority band for the
- // current cycle. This acts as a critical circuit breaker. A stateless inter-flow policy could otherwise repeatedly
- // select the same problematic queue in a tight loop of failures. Halting the band for one cycle prevents this.
- errIntraFlow = errors.New("intra-flow operation failure")
-)
+// ErrProcessorBusy is a sentinel error returned by the processor's Submit method, indicating that the processor's
+// internal buffer is momentarily full and cannot accept new work.
+var ErrProcessorBusy = errors.New("shard processor is busy")
-// clock defines an interface for getting the current time, allowing for dependency injection in tests.
-type clock interface {
- Now() time.Time
-}
-
-// ShardProcessor is the core worker of the `controller.FlowController`. It is paired one-to-one with a
-// `contracts.RegistryShard` instance and is responsible for all request lifecycle operations on that shard, including
-// enqueueing, dispatching, and expiry cleanup. It acts as the "data plane" worker that executes against the
-// concurrent-safe state provided by its shard.
+// ShardProcessor is the core worker of the FlowController.
+//
+// It is paired one-to-one with a RegistryShard instance and is responsible for all request lifecycle operations on that
+// shard, from the point an item is successfully submitted to it.
//
-// For a full rationale on the single-writer concurrency model, see the package-level documentation in `doc.go`.
+// # Request Lifecycle Management & Ownership
//
-// # Concurrency Guarantees and Race Conditions
+// The ShardProcessor takes ownership of a FlowItem only after it has been successfully sent to its internal enqueueChan
+// via Submit or SubmitOrBlock (i.e., when these methods return nil).
+// Once the Processor takes ownership, it is solely responsible for ensuring that item.Finalize() or
+// item.FinalizeWithOutcome() is called exactly once for that item, under all circumstances (dispatch, rejection, sweep,
+// or shutdown).
//
-// This model provides two key guarantees:
+// If Submit or SubmitOrBlock returns an error, ownership remains with the caller (the Controller), which must then
+// handle the finalization.
//
-// 1. **Safe Enqueueing**: The `Run` method's goroutine has exclusive ownership of all operations that *add* items to
-// queues. This makes the "check-then-act" sequence in `enqueue` (calling `hasCapacity` then `managedQ.Add`)
-// inherently atomic from a writer's perspective, preventing capacity breaches. While the background
-// `runExpiryCleanup` goroutine can concurrently *remove* items, this is a benign race; a concurrent removal only
-// creates more available capacity, ensuring the `hasCapacity` check remains valid.
+// # Concurrency Model
//
-// 2. **Idempotent Finalization**: The primary internal race is between the main `dispatchCycle` and the background
-// `runExpiryCleanup` goroutine, which might try to finalize the same `flowItem` simultaneously. This race is
-// resolved by the `flowItem.finalize` method, which uses `sync.Once` to guarantee that only one of these goroutines
-// can set the item's final state.
+// To ensure correctness and high performance, the processor uses a single-goroutine, actor-based model. The main run
+// loop is the sole writer for all state-mutating operations. This makes complex transactions (like capacity checks)
+// inherently atomic without coarse-grained locks.
type ShardProcessor struct {
- shard contracts.RegistryShard
- dispatchFilter BandFilter
- clock clock
- expiryCleanupInterval time.Duration
- logger logr.Logger
-
- // enqueueChan is the entry point for new requests to be processed by this shard's `Run` loop.
- enqueueChan chan *flowItem
- // wg is used to wait for background tasks like expiry cleanup to complete on shutdown.
+ shard contracts.RegistryShard
+ saturationDetector contracts.SaturationDetector
+ clock clock.WithTicker
+ cleanupSweepInterval time.Duration
+ logger logr.Logger
+
+ // lifecycleCtx controls the processor's lifetime. Monitored by Submit* methods for safe shutdown.
+ lifecycleCtx context.Context
+
+ // enqueueChan is the entry point for new requests.
+ enqueueChan chan *FlowItem
+
+ // wg is used to wait for background tasks (cleanup sweep) to complete on shutdown.
wg sync.WaitGroup
isShuttingDown atomic.Bool
shutdownOnce sync.Once
}
-// NewShardProcessor creates a new `ShardProcessor` instance.
+// NewShardProcessor creates a new ShardProcessor instance.
func NewShardProcessor(
+ ctx context.Context,
shard contracts.RegistryShard,
- dispatchFilter BandFilter,
- clock clock,
- expiryCleanupInterval time.Duration,
+ saturationDetector contracts.SaturationDetector,
+ clock clock.WithTicker,
+ cleanupSweepInterval time.Duration,
+ enqueueChannelBufferSize int,
logger logr.Logger,
) *ShardProcessor {
return &ShardProcessor{
- shard: shard,
- dispatchFilter: dispatchFilter,
- clock: clock,
- expiryCleanupInterval: expiryCleanupInterval,
- logger: logger,
- // A buffered channel decouples the processor from the distributor, allowing for a fast, asynchronous handoff of new
- // requests.
- enqueueChan: make(chan *flowItem, enqueueChannelBufferSize),
+ shard: shard,
+ saturationDetector: saturationDetector,
+ clock: clock,
+ cleanupSweepInterval: cleanupSweepInterval,
+ logger: logger,
+ lifecycleCtx: ctx,
+ enqueueChan: make(chan *FlowItem, enqueueChannelBufferSize),
}
}
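+
+// Illustrative wiring sketch (not part of this change; shard, detector, and logger are assumed placeholders, and
+// clock.RealClock{} from k8s.io/utils/clock is assumed to satisfy clock.WithTicker):
+//
+//	sp := NewShardProcessor(ctx, shard, detector, clock.RealClock{}, 1*time.Second, 1024, logger)
+//	go sp.Run(ctx) // the single-writer actor loop; see Submit below for the handoff contract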
-// Run is the main operational loop for the shard processor. It must be run as a goroutine.
+// Submit attempts a non-blocking handoff of an item to the processor's internal enqueue channel.
+//
+// Ownership Contract:
+// - Returns nil: The item was successfully handed off.
+// The ShardProcessor takes responsibility for calling Finalize on the item.
+// - Returns error: The item was not handed off.
+// Ownership of the FlowItem remains with the caller, who is responsible for calling Finalize.
//
-// # Loop Strategy: Interleaving Enqueue and Dispatch
+// Possible errors:
+// - ErrProcessorBusy: The processor's input channel is full.
+// - types.ErrFlowControllerNotRunning: The processor is shutting down.
+func (sp *ShardProcessor) Submit(item *FlowItem) error {
+ if sp.isShuttingDown.Load() {
+ return types.ErrFlowControllerNotRunning
+ }
+ select { // The default case makes this select non-blocking.
+ case sp.enqueueChan <- item:
+ return nil // Ownership transferred.
+ case <-sp.lifecycleCtx.Done():
+ return types.ErrFlowControllerNotRunning
+ default:
+ return ErrProcessorBusy
+ }
+}
+
+// SubmitOrBlock performs a blocking handoff of an item to the processor's internal enqueue channel.
+// It waits until the item is handed off, the caller's context is cancelled, or the processor shuts down.
//
-// The loop uses a `select` statement to interleave two primary tasks:
-// 1. Accepting new requests from the `enqueueChan`.
-// 2. Attempting to dispatch existing requests from queues via `dispatchCycle`.
+// Ownership Contract:
+// - Returns nil: The item was successfully handed off.
+// The ShardProcessor takes responsibility for calling Finalize on the item.
+// - Returns error: The item was not handed off.
+// Ownership of the FlowItem remains with the caller, who is responsible for calling Finalize.
//
-// This strategy is crucial for balancing responsiveness and throughput. When a new item arrives, it is immediately
-// enqueued, and a dispatch cycle is triggered. This gives high-priority new arrivals a chance to be dispatched quickly.
-// When no new items are arriving, the loop's `default` case continuously calls `dispatchCycle` to drain the existing
-// backlog, ensuring work continues.
+// Possible errors:
+// - ctx.Err(): The provided context was cancelled or its deadline exceeded.
+// - types.ErrFlowControllerNotRunning: The processor is shutting down.
+func (sp *ShardProcessor) SubmitOrBlock(ctx context.Context, item *FlowItem) error {
+ if sp.isShuttingDown.Load() {
+ return types.ErrFlowControllerNotRunning
+ }
+
+ select { // The absence of a default case makes this call blocking.
+ case sp.enqueueChan <- item:
+ return nil // Ownership transferred.
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-sp.lifecycleCtx.Done():
+ return types.ErrFlowControllerNotRunning
+ }
+}
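+
+// Caller-side sketch of the ownership contract above (illustrative only; reqCtx and the chosen rejection outcome are
+// assumptions, not prescribed by this package):
+//
+//	err := sp.Submit(item)
+//	if errors.Is(err, ErrProcessorBusy) {
+//		err = sp.SubmitOrBlock(reqCtx, item) // fall back to a bounded, blocking handoff
+//	}
+//	if err != nil {
+//		// Ownership was never transferred; the caller must finalize the item exactly once.
+//		item.FinalizeWithOutcome(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, err))
+//	}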
+
+// Run is the main operational loop for the shard processor. It must be run as a goroutine.
+// It uses a `select` statement to interleave accepting new requests with dispatching existing ones, balancing
+// responsiveness with throughput.
func (sp *ShardProcessor) Run(ctx context.Context) {
sp.logger.V(logutil.DEFAULT).Info("Shard processor run loop starting.")
defer sp.logger.V(logutil.DEFAULT).Info("Shard processor run loop stopped.")
sp.wg.Add(1)
- go sp.runExpiryCleanup(ctx)
+ go sp.runCleanupSweep(ctx)
// This is the main worker loop. It continuously processes incoming requests and dispatches queued requests until the
// context is cancelled. The `select` statement has three cases:
//
// 1. Context Cancellation: The highest priority is shutting down. If the context's `Done` channel is closed, the
// loop will drain all queues and exit. This is the primary exit condition.
- //
// 2. New Item Arrival: If an item is available on `enqueueChan`, it will be processed. This ensures that the
// processor is responsive to new work.
- //
// 3. Default (Dispatch): If neither of the above cases is ready, the `default` case executes, ensuring the loop is
// non-blocking. It continuously attempts to dispatch items from the existing backlog, preventing starvation and
// ensuring queues are drained.
@@ -175,8 +195,9 @@ func (sp *ShardProcessor) Run(ctx context.Context) {
sp.enqueue(item)
sp.dispatchCycle(ctx)
default:
+ // If no new items are arriving, continuously try to dispatch from the backlog.
if !sp.dispatchCycle(ctx) {
- // If no work was done, yield to other goroutines to prevent a tight, busy-loop when idle, but allow for
+ // If no work was done, yield to the scheduler to prevent a tight busy-loop when idle, while still allowing for
// immediate rescheduling.
runtime.Gosched()
}
@@ -184,91 +205,66 @@ func (sp *ShardProcessor) Run(ctx context.Context) {
}
}
-// Enqueue sends a new flow item to the processor's internal channel for asynchronous processing by its main `Run` loop.
-// If the processor is shutting down, it immediately finalizes the item with a shutdown error.
-func (sp *ShardProcessor) Enqueue(item *flowItem) {
- if sp.isShuttingDown.Load() {
- item.finalize(types.QueueOutcomeRejectedOther,
- fmt.Errorf("%w: %w", types.ErrRejected, types.ErrFlowControllerShutdown))
- return
- }
- sp.enqueueChan <- item
-}
-
-// enqueue is the internal implementation for adding a new item to a managed queue. It is always run from the single
-// main `Run` goroutine, making its "check-then-act" logic for capacity safe.
-func (sp *ShardProcessor) enqueue(item *flowItem) {
+// enqueue processes an item received from the enqueueChan.
+// It checks for external finalization, validates the flow's configuration, performs the capacity check, and either
+// admits the item to a queue or rejects it.
+func (sp *ShardProcessor) enqueue(item *FlowItem) {
req := item.OriginalRequest()
key := req.FlowKey()
- logger := log.FromContext(req.Context()).WithName("enqueue").WithValues(
- "flowKey", key,
- "flowID", key.ID,
- "priority", key.Priority,
- "reqID", req.ID(),
- "reqByteSize", req.ByteSize(),
- )
+ // --- Optimistic External Finalization Check ---
+ // Check if the item was finalized by the Controller (due to TTL/cancellation) while it was buffered in enqueueChan.
+ // This is an optimistic check to avoid unnecessary processing on items already considered dead.
+ // The ultimate guarantee of cleanup for any races is the runCleanupSweep mechanism.
+ if finalState := item.FinalState(); finalState != nil {
+ sp.logger.V(logutil.TRACE).Info("Item finalized externally before processing, discarding.",
+ "outcome", finalState.Outcome, "err", finalState.Err, "flowKey", key, "reqID", req.ID())
+ return
+ }
+ // --- Configuration Validation ---
managedQ, err := sp.shard.ManagedQueue(key)
if err != nil {
finalErr := fmt.Errorf("configuration error: failed to get queue for flow key %s: %w", key, err)
- logger.Error(finalErr, "Rejecting item.")
- item.finalize(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, finalErr))
+ sp.logger.Error(finalErr, "Rejecting item.", "flowKey", key, "reqID", req.ID())
+ item.FinalizeWithOutcome(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, finalErr))
return
}
band, err := sp.shard.PriorityBandAccessor(key.Priority)
if err != nil {
finalErr := fmt.Errorf("configuration error: failed to get priority band for priority %d: %w", key.Priority, err)
- logger.Error(finalErr, "Rejecting item.")
- item.finalize(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, finalErr))
+ sp.logger.Error(finalErr, "Rejecting item.", "flowKey", key, "reqID", req.ID())
+ item.FinalizeWithOutcome(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, finalErr))
return
}
- logger = logger.WithValues("priorityName", band.PriorityName())
+ // --- Capacity Check ---
+ // This check is safe because it is performed by the single-writer Run goroutine.
if !sp.hasCapacity(key.Priority, req.ByteSize()) {
- // This is an expected outcome, not a system error. Log at the default level with rich context.
- stats := sp.shard.Stats()
- bandStats := stats.PerPriorityBandStats[key.Priority]
- logger.V(logutil.DEFAULT).Info("Rejecting request, queue at capacity",
- "outcome", types.QueueOutcomeRejectedCapacity,
- "shardTotalBytes", stats.TotalByteSize,
- "shardCapacityBytes", stats.TotalCapacityBytes,
- "bandTotalBytes", bandStats.ByteSize,
- "bandCapacityBytes", bandStats.CapacityBytes,
- )
- item.finalize(types.QueueOutcomeRejectedCapacity, fmt.Errorf("%w: %w", types.ErrRejected, types.ErrQueueAtCapacity))
+ sp.logger.V(logutil.DEBUG).Info("Rejecting request, queue at capacity",
+ "flowKey", key, "reqID", req.ID(), "priorityName", band.PriorityName(), "reqByteSize", req.ByteSize())
+ item.FinalizeWithOutcome(types.QueueOutcomeRejectedCapacity, fmt.Errorf("%w: %w",
+ types.ErrRejected, types.ErrQueueAtCapacity))
return
}
- // This is an optimistic check to prevent a needless add/remove cycle for an item that was finalized (e.g., context
- // cancelled) during the handoff to this processor. A race condition still exists where an item can be finalized
- // after this check but before the `Add` call completes.
- //
- // This is considered acceptable because:
- // 1. The race window is extremely small.
- // 2. The background `runExpiryCleanup` goroutine acts as the ultimate guarantor of correctness, as it will
- // eventually find and evict any finalized item that slips through this check and is added to a queue.
- if item.isFinalized() {
- outcome, err := item.FinalState()
- logger.V(logutil.VERBOSE).Info("Item finalized before adding to queue, ignoring.", "outcome", outcome, "err", err)
- return
- }
-
- // This is the point of commitment. After this call, the item is officially in the queue and is the responsibility of
- // the dispatch or cleanup loops to finalize.
+ // --- Commitment Point ---
+ // The item is admitted. The ManagedQueue.Add implementation is responsible for calling item.SetHandle() atomically.
if err := managedQ.Add(item); err != nil {
finalErr := fmt.Errorf("failed to add item to queue for flow key %s: %w", key, err)
- logger.Error(finalErr, "Rejecting item.")
- item.finalize(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, finalErr))
+ sp.logger.Error(finalErr, "Rejecting item post-admission.",
+ "flowKey", key, "reqID", req.ID(), "priorityName", band.PriorityName())
+ item.FinalizeWithOutcome(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, finalErr))
return
}
- logger.V(logutil.TRACE).Info("Item enqueued.")
+ sp.logger.V(logutil.TRACE).Info("Item enqueued.",
+ "flowKey", key, "reqID", req.ID(), "priorityName", band.PriorityName())
}
-// hasCapacity checks if the shard and the specific priority band have enough capacity to accommodate an item of a given
-// size.
-func (sp *ShardProcessor) hasCapacity(priority uint, itemByteSize uint64) bool {
+// hasCapacity checks if the shard and the specific priority band have enough capacity.
+// This check reflects actual resource utilization, including "zombie" items (finalized but unswept), to prevent
+// physical resource overcommitment.
+func (sp *ShardProcessor) hasCapacity(priority int, itemByteSize uint64) bool {
if itemByteSize == 0 {
return true
}
@@ -278,311 +274,201 @@ func (sp *ShardProcessor) hasCapacity(priority uint, itemByteSize uint64) bool {
}
bandStats, ok := stats.PerPriorityBandStats[priority]
if !ok {
- // This should not happen if the registry is consistent, but we fail closed just in case.
- return false
+ return false // Fail closed if configuration is inconsistent.
}
return bandStats.ByteSize+itemByteSize <= bandStats.CapacityBytes
}
-// dispatchCycle attempts to dispatch a single item by iterating through all priority bands from highest to lowest.
+// dispatchCycle attempts to dispatch a single item by iterating through priority bands from highest to lowest.
// It applies the configured policies for each band to select an item and then attempts to dispatch it.
// It returns true if an item was successfully dispatched, and false otherwise.
+// It enforces Head-of-Line (HoL) blocking if the selected item is saturated.
//
-// # Error Handling Philosophy
-//
-// The engine employs a robust, two-tiered error handling strategy to isolate failures and maximize system availability.
-// This is managed via the `errInterFlow` and `errIntraFlow` sentinel errors.
-//
-// - Inter-Flow Failures: If a failure occurs while selecting a flow (e.g., the `InterFlowDispatchPolicy` fails), the
-// processor aborts the *current priority band* and immediately moves to the next one. This promotes work
-// conservation, ensuring a single misconfigured band does not halt progress for the entire system.
+// # Work Conservation and Head-of-Line (HoL) Blocking
//
-// - Intra-Flow Failures: If a failure occurs *after* a flow has been selected (e.g., the `IntraFlowDispatchPolicy`
-// fails), the processor aborts the *entire priority band* for the current cycle. This is a critical circuit
-// breaker. An inter-flow policy that is not stateful with respect to past failures could otherwise repeatedly
-// select the same problematic queue, causing a tight loop of failures. Halting the band for one cycle prevents
-// this.
+// The cycle attempts to be work-conserving by skipping bands where selection fails.
+// However, if a selected item is saturated (cannot be scheduled), the cycle stops immediately. This enforces HoL
+// blocking to respect the policy's decision and prevent priority inversion, where dispatching lower-priority work might
+// exacerbate the saturation affecting the high-priority item.
func (sp *ShardProcessor) dispatchCycle(ctx context.Context) bool {
- baseLogger := sp.logger.WithName("dispatchCycle")
-
- // FUTURE EXTENSION POINT: The iteration over priority bands is currently a simple, strict-priority loop.
- // This could be abstracted into a third policy tier (e.g., an `InterBandDispatchPolicy`) if more complex scheduling
- // between bands, such as Weighted Fair Queuing (WFQ), is ever required. For now, strict priority is sufficient.
for _, priority := range sp.shard.AllOrderedPriorityLevels() {
originalBand, err := sp.shard.PriorityBandAccessor(priority)
if err != nil {
- baseLogger.Error(err, "Failed to get PriorityBandAccessor, skipping band", "priority", priority)
+ sp.logger.Error(err, "Failed to get PriorityBandAccessor, skipping band", "priority", priority)
continue
}
- logger := baseLogger.WithValues("priority", priority, "priorityName", originalBand.PriorityName())
- // Apply the configured filter to get a view of only the dispatchable flows.
- dispatchableBand, shouldPause := sp.dispatchFilter(ctx, originalBand, logger)
- if shouldPause {
- return false // A global gate told us to stop the entire cycle.
- }
- if dispatchableBand == nil {
- // A nil return from the filter indicates the fast path: no filtering was needed.
- dispatchableBand = originalBand
- }
-
- // Pass the (potentially filtered) band to the policies.
- item, err := sp.selectItem(dispatchableBand, logger)
+ item, err := sp.selectItem(originalBand)
if err != nil {
- // The error handling strategy depends on the type of failure (inter- vs. intra-flow).
- if errors.Is(err, errIntraFlow) {
- logger.Error(err, "Intra-flow policy failure, skipping priority band for this cycle")
- } else {
- logger.Error(err, "Inter-flow policy or configuration failure, skipping priority band for this cycle")
- }
- continue
+ sp.logger.Error(err, "Failed to select item, skipping priority band for this cycle",
+ "priority", priority, "priorityName", originalBand.PriorityName())
+ continue // Continue to the next band to maximize work conservation.
}
if item == nil {
- // This is the common case where a priority band has no items to dispatch.
- logger.V(logutil.TRACE).Info("No item selected by dispatch policies, skipping band")
continue
}
- logger = logger.WithValues(
- "flowKey", item.OriginalRequest().FlowKey(),
- "flowID", item.OriginalRequest().FlowKey().ID,
- "flowPriority", item.OriginalRequest().FlowKey().Priority,
- "reqID", item.OriginalRequest().ID(),
- "reqByteSize", item.OriginalRequest().ByteSize())
-
- if err := sp.dispatchItem(item, logger); err != nil {
- // All errors from dispatchItem are considered intra-flow and unrecoverable for this band in this cycle.
- logger.Error(err, "Failed to dispatch item, skipping priority band for this cycle")
- continue
+
+ // --- Viability Check (Saturation/HoL Blocking) ---
+ req := item.OriginalRequest()
+ candidatePods := req.CandidatePodsForScheduling()
+ if sp.saturationDetector.IsSaturated(ctx, candidatePods) {
+ sp.logger.V(logutil.DEBUG).Info("Policy's chosen item is saturated; enforcing HoL blocking.",
+ "flowKey", req.FlowKey(), "reqID", req.ID(), "priorityName", originalBand.PriorityName())
+ // Stop the dispatch cycle entirely to respect the policy's decision and prevent priority inversion, where
+ // lower-priority work might exacerbate the saturation affecting high-priority work.
+ return false
+ }
+
+ // --- Dispatch ---
+ if err := sp.dispatchItem(item); err != nil {
+ sp.logger.Error(err, "Failed to dispatch item, skipping priority band for this cycle",
+ "flowKey", req.FlowKey(), "reqID", req.ID(), "priorityName", originalBand.PriorityName())
+ continue // Continue to the next band to maximize work conservation.
}
- // A successful dispatch occurred, so we return true to signal that work was done.
return true
}
- // No items were dispatched in this cycle across all priority bands.
return false
}
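+
+// Illustrative saturation-detector stub for the viability check above (a sketch only; the IsSaturated signature is
+// assumed from its usage in this file and in the accompanying tests):
+//
+//	type alwaysSaturated struct{}
+//
+//	func (alwaysSaturated) IsSaturated(context.Context, []metrics.PodMetrics) bool { return true }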
-// selectItem applies the configured inter- and intra-flow dispatch policies to select a single item from a priority
-// band.
-func (sp *ShardProcessor) selectItem(
- band framework.PriorityBandAccessor,
- logger logr.Logger,
-) (types.QueueItemAccessor, error) {
+// selectItem applies the configured inter- and intra-flow dispatch policies to select a single item.
+func (sp *ShardProcessor) selectItem(band framework.PriorityBandAccessor) (types.QueueItemAccessor, error) {
interP, err := sp.shard.InterFlowDispatchPolicy(band.Priority())
if err != nil {
- return nil, fmt.Errorf("%w: could not get InterFlowDispatchPolicy: %w", errInterFlow, err)
+ return nil, fmt.Errorf("could not get InterFlowDispatchPolicy: %w", err)
}
queue, err := interP.SelectQueue(band)
if err != nil {
- return nil, fmt.Errorf("%w: InterFlowDispatchPolicy %q failed to select queue: %w",
- errInterFlow, interP.Name(), err)
+ return nil, fmt.Errorf("InterFlowDispatchPolicy %q failed to select queue: %w", interP.Name(), err)
}
if queue == nil {
- logger.V(logutil.TRACE).Info("No queue selected by InterFlowDispatchPolicy")
return nil, nil
}
key := queue.FlowKey()
- logger = logger.WithValues(
- "selectedFlowKey", key,
- "selectedFlowID", key.ID,
- "selectedFlowPriority", key.Priority)
intraP, err := sp.shard.IntraFlowDispatchPolicy(key)
if err != nil {
- // This is an intra-flow failure because we have already successfully selected a queue.
- return nil, fmt.Errorf("%w: could not get IntraFlowDispatchPolicy for flow %q: %w", errIntraFlow, key, err)
+ return nil, fmt.Errorf("could not get IntraFlowDispatchPolicy for flow %s: %w", key, err)
}
item, err := intraP.SelectItem(queue)
if err != nil {
- return nil, fmt.Errorf("%w: IntraFlowDispatchPolicy %q failed to select item for flow %q: %w",
- errIntraFlow, intraP.Name(), key, err)
- }
- if item == nil {
- logger.V(logutil.TRACE).Info("No item selected by IntraFlowDispatchPolicy")
- return nil, nil
+ return nil, fmt.Errorf("IntraFlowDispatchPolicy %q failed to select item for flow %s: %w", intraP.Name(), key, err)
}
return item, nil
}
-// dispatchItem handles the final steps of dispatching an item after it has been selected by policies. This includes
-// removing it from its queue, checking for last-minute expiry, and finalizing its outcome.
-func (sp *ShardProcessor) dispatchItem(itemAcc types.QueueItemAccessor, logger logr.Logger) error {
- logger = logger.WithName("dispatchItem")
-
+// dispatchItem handles the final steps of dispatching an item: removing it from the queue and finalizing its outcome.
+func (sp *ShardProcessor) dispatchItem(itemAcc types.QueueItemAccessor) error {
req := itemAcc.OriginalRequest()
- // We must look up the queue by its specific priority, as a flow might have draining queues at other levels.
- managedQ, err := sp.shard.ManagedQueue(req.FlowKey())
+ key := req.FlowKey()
+ managedQ, err := sp.shard.ManagedQueue(key)
if err != nil {
- return fmt.Errorf("%w: failed to get ManagedQueue for flow %q: %w", errIntraFlow, req.FlowKey(), err)
+ return fmt.Errorf("failed to get ManagedQueue for flow %s: %w", key, err)
}
- // The core mutation: remove the item from the queue.
removedItemAcc, err := managedQ.Remove(itemAcc.Handle())
if err != nil {
- // This can happen benignly if the item was already removed by the expiry cleanup loop between the time it was
- // selected by the policy and the time this function is called.
- logger.V(logutil.VERBOSE).Info("Item already removed from queue, likely by expiry cleanup", "err", err)
- return fmt.Errorf("%w: failed to remove item %q from queue for flow %q: %w",
- errIntraFlow, req.ID(), req.FlowKey(), err)
+ // This happens benignly if the item was already removed by the cleanup sweep loop.
+ // We log it at a low level for visibility but return nil so the dispatch cycle proceeds.
+ sp.logger.V(logutil.DEBUG).Info("Failed to remove item during dispatch (likely already finalized and swept).",
+ "flowKey", key, "reqID", req.ID(), "error", err)
+ return nil
}
- removedItem, ok := removedItemAcc.(*flowItem)
- if !ok {
- // This indicates a severe logic error where a queue returns an item of an unexpected type. This violates a
- // core system invariant: all items managed by the processor must be of type *flowItem. This is an unrecoverable
- // state for this shard.
- unexpectedItemErr := fmt.Errorf("%w: internal error: item %q of type %T is not a *flowItem",
- errIntraFlow, removedItemAcc.OriginalRequest().ID(), removedItemAcc)
- panic(unexpectedItemErr)
- }
-
- // Final check for expiry/cancellation right before dispatch.
- isExpired, outcome, expiryErr := checkItemExpiry(removedItem, sp.clock.Now())
- if isExpired {
- // Ensure we always have a non-nil error to wrap for consistent logging and error handling.
- finalErr := expiryErr
- if finalErr == nil {
- finalErr = errors.New("item finalized before dispatch")
- }
- logger.V(logutil.VERBOSE).Info("Item expired at time of dispatch, evicting", "outcome", outcome,
- "err", finalErr)
- removedItem.finalize(outcome, fmt.Errorf("%w: %w", types.ErrEvicted, finalErr))
- // Return an error to signal that the dispatch did not succeed.
- return fmt.Errorf("%w: item %q expired before dispatch: %w", errIntraFlow, req.ID(), finalErr)
- }
-
- // Finalize the item as dispatched.
- removedItem.finalize(types.QueueOutcomeDispatched, nil)
- logger.V(logutil.TRACE).Info("Item dispatched.")
+ removedItem := removedItemAcc.(*FlowItem)
+ sp.logger.V(logutil.TRACE).Info("Item dispatched.", "flowKey", req.FlowKey(), "reqID", req.ID())
+ removedItem.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil)
return nil
}
-// checkItemExpiry checks if an item has been cancelled (via its context) or has exceeded its TTL. It returns true if
-// the item is expired, along with the corresponding outcome and error.
-//
-// This function provides "defense in depth" against race conditions. It is the authoritative check that is called from
-// multiple locations (the dispatch loop and the cleanup loop) to determine if an item should be evicted. Its first
-// action is to check if the item has *already* been finalized by a competing goroutine, ensuring that the final outcome
-// is decided exactly once.
-func checkItemExpiry(
- itemAcc types.QueueItemAccessor,
- now time.Time,
-) (isExpired bool, outcome types.QueueOutcome, err error) {
- item, ok := itemAcc.(*flowItem)
- if !ok {
- // This indicates a severe logic error where a queue returns an item of an unexpected type. This violates a
- // core system invariant: all items managed by the processor must be of type *flowItem. This is an unrecoverable
- // state for this shard.
- unexpectedItemErr := fmt.Errorf("internal error: item %q of type %T is not a *flowItem",
- itemAcc.OriginalRequest().ID(), itemAcc)
- panic(unexpectedItemErr)
- }
-
- // This check is a critical defense against race conditions. If another goroutine (e.g., the cleanup loop) has
- // already finalized this item, we must respect that outcome.
- if item.isFinalized() {
- outcome, err := item.FinalState()
- return true, outcome, err
- }
-
- // Check if the request's context has been cancelled.
- if ctxErr := item.OriginalRequest().Context().Err(); ctxErr != nil {
- return true, types.QueueOutcomeEvictedContextCancelled, fmt.Errorf("%w: %w", types.ErrContextCancelled, ctxErr)
- }
-
- // Check if the item has outlived its TTL.
- if item.EffectiveTTL() > 0 && now.Sub(item.EnqueueTime()) > item.EffectiveTTL() {
- return true, types.QueueOutcomeEvictedTTL, types.ErrTTLExpired
- }
-
- return false, types.QueueOutcomeNotYetFinalized, nil
-}
-
-// runExpiryCleanup starts a background goroutine that periodically scans all queues on the shard for expired items.
-func (sp *ShardProcessor) runExpiryCleanup(ctx context.Context) {
+// runCleanupSweep starts a background goroutine that periodically scans all queues for externally finalized items
+// ("zombie" items) and removes them in batches.
+func (sp *ShardProcessor) runCleanupSweep(ctx context.Context) {
defer sp.wg.Done()
- logger := sp.logger.WithName("runExpiryCleanup")
- logger.V(logutil.DEFAULT).Info("Shard expiry cleanup goroutine starting.")
- defer logger.V(logutil.DEFAULT).Info("Shard expiry cleanup goroutine stopped.")
+ logger := sp.logger.WithName("runCleanupSweep")
+ logger.V(logutil.DEFAULT).Info("Shard cleanup sweep goroutine starting.")
+ defer logger.V(logutil.DEFAULT).Info("Shard cleanup sweep goroutine stopped.")
- ticker := time.NewTicker(sp.expiryCleanupInterval)
+ ticker := sp.clock.NewTicker(sp.cleanupSweepInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
- case now := <-ticker.C:
- sp.cleanupExpired(now)
+ case <-ticker.C():
+ sp.sweepFinalizedItems()
}
}
}
-// cleanupExpired performs a single scan of all queues on the shard, removing and finalizing any items that have
-// expired due to TTL or context cancellation.
-func (sp *ShardProcessor) cleanupExpired(now time.Time) {
- processFn := func(managedQ contracts.ManagedQueue, queueLogger logr.Logger) {
- // This predicate identifies items to be removed by the Cleanup call.
- predicate := func(item types.QueueItemAccessor) bool {
- isExpired, _, _ := checkItemExpiry(item, now)
- return isExpired
+// sweepFinalizedItems performs a single scan of all queues, removing finalized items in batch and releasing their
+// memory.
+func (sp *ShardProcessor) sweepFinalizedItems() {
+ processFn := func(managedQ contracts.ManagedQueue, logger logr.Logger) {
+ key := managedQ.FlowQueueAccessor().FlowKey()
+ predicate := func(itemAcc types.QueueItemAccessor) bool {
+ return itemAcc.(*FlowItem).FinalState() != nil
}
-
removedItems, err := managedQ.Cleanup(predicate)
if err != nil {
- queueLogger.Error(err, "Error during ManagedQueue Cleanup")
+ logger.Error(err, "Error during ManagedQueue Cleanup", "flowKey", key)
}
-
- // Finalize all the items that were removed.
- sp.finalizeExpiredItems(removedItems, now, queueLogger)
+ logger.V(logutil.DEBUG).Info("Swept finalized items and released capacity.",
+ "flowKey", key, "count", len(removedItems))
}
- sp.processAllQueuesConcurrently("cleanupExpired", processFn)
+ sp.processAllQueuesConcurrently("sweepFinalizedItems", processFn)
}
-// shutdown handles the graceful termination of the processor. It uses sync.Once to guarantee that the shutdown logic is
-// executed exactly once, regardless of whether it's triggered by context cancellation or the closing of the enqueue
-// channel.
+// shutdown handles the graceful termination of the processor, ensuring every pending item (buffered in the channel
+// or still resident in a queue) is finalized.
func (sp *ShardProcessor) shutdown() {
sp.shutdownOnce.Do(func() {
- // Set the atomic bool so that any new calls to Enqueue will fail fast.
sp.isShuttingDown.Store(true)
sp.logger.V(logutil.DEFAULT).Info("Shard processor shutting down.")
- // Drain the channel BEFORE closing it. This prevents a panic from any goroutine that is currently blocked trying to
- // send to the channel. We read until it's empty.
- DrainLoop:
+ DrainLoop: // Drain the enqueueChan to finalize buffered items.
for {
select {
case item := <-sp.enqueueChan:
- if item == nil { // This is a safeguard against logic errors in the distributor.
+ if item == nil {
continue
}
- item.finalize(types.QueueOutcomeRejectedOther,
- fmt.Errorf("%w: %w", types.ErrRejected, types.ErrFlowControllerShutdown))
+ // Finalize buffered items.
+ item.FinalizeWithOutcome(types.QueueOutcomeRejectedOther,
+ fmt.Errorf("%w: %w", types.ErrRejected, types.ErrFlowControllerNotRunning))
default:
- // The channel is empty, we can now safely close it.
break DrainLoop
}
}
- close(sp.enqueueChan)
-
- // Evict all remaining items from the queues.
+ // We do not close enqueueChan because external goroutines (Controller) send on it.
+ // The channel will be garbage collected when the processor terminates.
sp.evictAll()
})
}
-// evictAll drains all queues on the shard and finalizes every item with a shutdown error. This is called when the
-// processor is shutting down to ensure no requests are left in a pending state.
+// evictAll drains all queues on the shard, finalizes every item, and releases their memory.
func (sp *ShardProcessor) evictAll() {
- processFn := func(managedQ contracts.ManagedQueue, queueLogger logr.Logger) {
+ processFn := func(managedQ contracts.ManagedQueue, logger logr.Logger) {
+ key := managedQ.FlowQueueAccessor().FlowKey()
removedItems, err := managedQ.Drain()
if err != nil {
- queueLogger.Error(err, "Error during ManagedQueue Drain")
+ logger.Error(err, "Error during ManagedQueue Drain", "flowKey", key)
}
- // Finalize all the items that were removed.
- getOutcome := func(_ types.QueueItemAccessor) (types.QueueOutcome, error) {
- return types.QueueOutcomeEvictedOther, fmt.Errorf("%w: %w", types.ErrEvicted, types.ErrFlowControllerShutdown)
+ outcome := types.QueueOutcomeEvictedOther
+ errShutdown := fmt.Errorf("%w: %w", types.ErrEvicted, types.ErrFlowControllerNotRunning)
+ for _, i := range removedItems {
+ item, ok := i.(*FlowItem)
+ if !ok {
+ logger.Error(fmt.Errorf("internal error: unexpected type %T", i),
+ "Panic condition detected during shutdown", "flowKey", key)
+ continue
+ }
+
+ // Finalization is idempotent; safe to call even if already finalized externally.
+ item.FinalizeWithOutcome(outcome, errShutdown)
+ logger.V(logutil.TRACE).Info("Item evicted during shutdown.",
+ "flowKey", key, "reqID", item.OriginalRequest().ID())
}
- sp.finalizeItems(removedItems, queueLogger, getOutcome)
}
sp.processAllQueuesConcurrently("evictAll", processFn)
}
@@ -650,38 +536,3 @@ func (sp *ShardProcessor) processAllQueuesConcurrently(
close(tasks) // Close the channel to signal workers to exit.
wg.Wait() // Wait for all workers to finish.
}
-
-// finalizeItems is a helper to iterate over a slice of items, safely cast them, and finalize them with an outcome
-// determined by the `getOutcome` function.
-func (sp *ShardProcessor) finalizeItems(
- items []types.QueueItemAccessor,
- logger logr.Logger,
- getOutcome func(item types.QueueItemAccessor) (types.QueueOutcome, error),
-) {
- for _, i := range items {
- item, ok := i.(*flowItem)
- if !ok {
- unexpectedItemErr := fmt.Errorf("internal error: item %q of type %T is not a *flowItem",
- i.OriginalRequest().ID(), i)
- logger.Error(unexpectedItemErr, "Panic condition detected during finalization", "item", i)
- continue
- }
-
- outcome, err := getOutcome(i)
- item.finalize(outcome, err)
- logger.V(logutil.TRACE).Info("Item finalized", "reqID", item.OriginalRequest().ID(),
- "outcome", outcome, "err", err)
- }
-}
-
-// finalizeExpiredItems is a specialized version of finalizeItems for items that are known to be expired. It determines
-// the precise reason for expiry and finalizes the item accordingly.
-func (sp *ShardProcessor) finalizeExpiredItems(items []types.QueueItemAccessor, now time.Time, logger logr.Logger) {
- getOutcome := func(item types.QueueItemAccessor) (types.QueueOutcome, error) {
- // We don't need the `isExpired` boolean here because we know it's true, but this function conveniently returns the
- // precise outcome and error.
- _, outcome, expiryErr := checkItemExpiry(item, now)
- return outcome, fmt.Errorf("%w: %w", types.ErrEvicted, expiryErr)
- }
- sp.finalizeItems(items, logger, getOutcome)
-}
diff --git a/pkg/epp/flowcontrol/controller/internal/processor_test.go b/pkg/epp/flowcontrol/controller/internal/processor_test.go
index 67657a9a4..73fc5b13d 100644
--- a/pkg/epp/flowcontrol/controller/internal/processor_test.go
+++ b/pkg/epp/flowcontrol/controller/internal/processor_test.go
@@ -14,29 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
-//
-// A Note on the Testing Strategy for `ShardProcessor`
-//
-// The `ShardProcessor` is a complex concurrent orchestrator. Testing it with concrete implementations would lead to
-// flaky, non-deterministic tests. Therefore, we use a high-fidelity `testHarness` with stateful mocks to enable
-// reliable and deterministic testing. This is a deliberate and necessary choice for several key reasons:
-//
-// 1. Deterministic Race Simulation: The harness allows us to pause mock execution at critical moments, making it
-// possible to deterministically simulate and verify the processor's behavior during race conditions (e.g., the
-// dispatch vs. expiry race). This is impossible with concrete implementations without resorting to unreliable
-// sleeps.
-//
-// 2. Failure Mode Simulation: We can trigger specific, on-demand errors from dependencies to verify the processor's
-// resilience and complex error-handling logic (e.g., the `errIntraFlow` circuit breaker).
-//
-// 3. Interaction and Isolation Testing: Mocks allow us to isolate the `ShardProcessor` from its dependencies. This
-// ensures that tests are verifying the processor's orchestration logic (i.e., that it calls its dependencies
-// correctly) and are not affected by confounding bugs in those dependencies.
-//
-// In summary, this is a prerequisite for reliably testing a concurrent engine, not just a simple data
-// structure.
-//
-
package internal
import (
@@ -44,7 +21,7 @@ import (
"errors"
"fmt"
"os"
- "slices"
+ "sort"
"sync"
"sync/atomic"
"testing"
@@ -53,9 +30,11 @@ import (
"github.com/go-logr/logr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ testclock "k8s.io/utils/clock/testing"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts/mocks"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework"
@@ -79,28 +58,6 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}
-// mockClock allows for controlling time in tests.
-type mockClock struct {
- mu sync.Mutex
- currentTime time.Time
-}
-
-func newMockClock() *mockClock {
- return &mockClock{currentTime: time.Now()}
-}
-
-func (c *mockClock) Now() time.Time {
- c.mu.Lock()
- defer c.mu.Unlock()
- return c.currentTime
-}
-
-func (c *mockClock) Advance(d time.Duration) {
- c.mu.Lock()
- defer c.mu.Unlock()
- c.currentTime = c.currentTime.Add(d)
-}
-
// testHarness provides a unified, mock-based testing environment for the ShardProcessor. It centralizes all mock state
// and provides helper methods for setting up tests and managing the processor's lifecycle.
type testHarness struct {
@@ -114,15 +71,16 @@ type testHarness struct {
startSignal chan struct{}
// Core components under test
- processor *ShardProcessor
- mockClock *mockClock
- logger logr.Logger
+ processor *ShardProcessor
+ clock *testclock.FakeClock
+ logger logr.Logger
+ saturationDetector *mocks.MockSaturationDetector
// --- Centralized Mock State ---
// The harness's mutex protects the single source of truth for all mock state.
mu sync.Mutex
queues map[types.FlowKey]*mocks.MockManagedQueue
- priorityFlows map[uint][]types.FlowKey // Key: `priority`
+ priorityFlows map[int][]types.FlowKey // Key: `priority`
// Customizable policy logic for tests to override.
interFlowPolicySelectQueue func(band framework.PriorityBandAccessor) (framework.FlowQueueAccessor, error)
@@ -133,14 +91,16 @@ type testHarness struct {
func newTestHarness(t *testing.T, expiryCleanupInterval time.Duration) *testHarness {
t.Helper()
h := &testHarness{
- t: t,
- MockRegistryShard: &mocks.MockRegistryShard{},
- mockClock: newMockClock(),
- logger: logr.Discard(),
- startSignal: make(chan struct{}),
- queues: make(map[types.FlowKey]*mocks.MockManagedQueue),
- priorityFlows: make(map[uint][]types.FlowKey),
+ t: t,
+ MockRegistryShard: &mocks.MockRegistryShard{},
+ clock: testclock.NewFakeClock(time.Now()),
+ logger: logr.Discard(),
+ saturationDetector: &mocks.MockSaturationDetector{},
+ startSignal: make(chan struct{}),
+ queues: make(map[types.FlowKey]*mocks.MockManagedQueue),
+ priorityFlows: make(map[int][]types.FlowKey),
}
+ h.ctx, h.cancel = context.WithCancel(context.Background())
// Wire up the harness to provide the mock implementations for the shard's dependencies.
h.ManagedQueueFunc = h.managedQueue
@@ -153,22 +113,24 @@ func newTestHarness(t *testing.T, expiryCleanupInterval time.Duration) *testHarn
h.StatsFunc = func() contracts.ShardStats {
return contracts.ShardStats{
TotalCapacityBytes: 1e9,
- PerPriorityBandStats: map[uint]contracts.PriorityBandStats{
+ PerPriorityBandStats: map[int]contracts.PriorityBandStats{
testFlow.Priority: {CapacityBytes: 1e9},
},
}
}
- // Use a default pass-through filter.
- filter := func(
- ctx context.Context,
- band framework.PriorityBandAccessor,
- logger logr.Logger,
- ) (framework.PriorityBandAccessor, bool) {
- return nil, false
- }
- h.processor = NewShardProcessor(h, filter, h.mockClock, expiryCleanupInterval, h.logger)
+ h.processor = NewShardProcessor(
+ h.ctx,
+ h,
+ h.saturationDetector,
+ h.clock,
+ expiryCleanupInterval,
+ 100,
+ h.logger)
require.NotNil(t, h.processor, "NewShardProcessor should not return nil")
+
+ t.Cleanup(func() { h.Stop() })
+
return h
}
@@ -202,23 +164,22 @@ func (h *testHarness) Stop() {
}
// waitForFinalization blocks until an item is finalized or a timeout is reached.
-func (h *testHarness) waitForFinalization(item *flowItem) (types.QueueOutcome, error) {
+func (h *testHarness) waitForFinalization(item *FlowItem) (types.QueueOutcome, error) {
h.t.Helper()
select {
- case <-item.Done():
- return item.FinalState()
+ case finalState := <-item.Done():
+ return finalState.Outcome, finalState.Err
case <-time.After(testWaitTimeout):
h.t.Fatalf("Timed out waiting for item %q to be finalized", item.OriginalRequest().ID())
return types.QueueOutcomeNotYetFinalized, nil
}
}
-// newTestItem creates a new flowItem for testing purposes.
-func (h *testHarness) newTestItem(id string, key types.FlowKey, ttl time.Duration) *flowItem {
+// newTestItem creates a new FlowItem for testing purposes.
+func (h *testHarness) newTestItem(id string, key types.FlowKey, ttl time.Duration) *FlowItem {
h.t.Helper()
- ctx := log.IntoContext(context.Background(), h.logger)
- req := typesmocks.NewMockFlowControlRequest(100, id, key, ctx)
- return NewItem(req, ttl, h.mockClock.Now())
+ req := typesmocks.NewMockFlowControlRequest(100, id, key)
+ return NewItem(req, ttl, h.clock.Now())
}
// addQueue centrally registers a new mock queue for a given flow, ensuring all harness components are aware of it.
@@ -226,13 +187,9 @@ func (h *testHarness) addQueue(key types.FlowKey) *mocks.MockManagedQueue {
h.t.Helper()
h.mu.Lock()
defer h.mu.Unlock()
-
mockQueue := &mocks.MockManagedQueue{FlowKeyV: key}
h.queues[key] = mockQueue
-
- // Add the key to the correct priority band, creating the band if needed.
h.priorityFlows[key.Priority] = append(h.priorityFlows[key.Priority], key)
-
return mockQueue
}
@@ -249,20 +206,23 @@ func (h *testHarness) managedQueue(key types.FlowKey) (contracts.ManagedQueue, e
}
// allOrderedPriorityLevels provides the mock implementation for the `RegistryShard` interface.
-func (h *testHarness) allOrderedPriorityLevels() []uint {
+func (h *testHarness) allOrderedPriorityLevels() []int {
h.mu.Lock()
defer h.mu.Unlock()
- prios := make([]uint, 0, len(h.priorityFlows))
+ prios := make([]int, 0, len(h.priorityFlows))
for p := range h.priorityFlows {
prios = append(prios, p)
}
- slices.Sort(prios)
+ sort.Slice(prios, func(i, j int) bool {
+ return prios[i] > prios[j]
+ })
+
return prios
}
// priorityBandAccessor provides the mock implementation for the `RegistryShard` interface. It acts as a factory for a
// fully-configured, stateless mock that is safe for concurrent use.
-func (h *testHarness) priorityBandAccessor(p uint) (framework.PriorityBandAccessor, error) {
+func (h *testHarness) priorityBandAccessor(p int) (framework.PriorityBandAccessor, error) {
band := &frameworkmocks.MockPriorityBandAccessor{PriorityV: p}
// Safely get a snapshot of the flow IDs under a lock.
@@ -288,7 +248,7 @@ func (h *testHarness) priorityBandAccessor(p uint) (framework.PriorityBandAccess
}
// interFlowDispatchPolicy provides the mock implementation for the `contracts.RegistryShard` interface.
-func (h *testHarness) interFlowDispatchPolicy(p uint) (framework.InterFlowDispatchPolicy, error) {
+func (h *testHarness) interFlowDispatchPolicy(p int) (framework.InterFlowDispatchPolicy, error) {
policy := &frameworkmocks.MockInterFlowDispatchPolicy{}
// If the test provided a custom implementation, use it.
if h.interFlowPolicySelectQueue != nil {
@@ -331,9 +291,9 @@ func (h *testHarness) intraFlowDispatchPolicy(types.FlowKey) (framework.IntraFlo
func TestShardProcessor(t *testing.T) {
t.Parallel()
- // Lifecycle tests use the processor's main `Run` loop to verify the complete end-to-end lifecycle of a request, from
+ // Integration tests use the processor's main `Run` loop to verify the complete end-to-end lifecycle of a request, from
- // `Enqueue` to its final outcome.
+ // `Submit` to its final outcome.
- t.Run("Lifecycle", func(t *testing.T) {
+ t.Run("Integration", func(t *testing.T) {
t.Parallel()
t.Run("should dispatch item successfully", func(t *testing.T) {
@@ -341,12 +301,11 @@ func TestShardProcessor(t *testing.T) {
// --- ARRANGE ---
h := newTestHarness(t, testCleanupTick)
item := h.newTestItem("req-dispatch-success", testFlow, testTTL)
- h.addQueue(types.FlowKey{ID: testFlow.ID, Priority: testFlow.Priority})
+ h.addQueue(testFlow)
// --- ACT ---
h.Start()
- defer h.Stop()
- h.processor.Enqueue(item)
+ require.NoError(t, h.processor.Submit(item), "precondition: Submit should not fail")
h.Go()
// --- ASSERT ---
@@ -362,15 +321,14 @@ func TestShardProcessor(t *testing.T) {
item := h.newTestItem("req-capacity-reject", testFlow, testTTL)
h.addQueue(testFlow)
h.StatsFunc = func() contracts.ShardStats {
- return contracts.ShardStats{PerPriorityBandStats: map[uint]contracts.PriorityBandStats{
+ return contracts.ShardStats{PerPriorityBandStats: map[int]contracts.PriorityBandStats{
testFlow.Priority: {CapacityBytes: 50}, // 50 is less than item size of 100
}}
}
// --- ACT ---
h.Start()
- defer h.Stop()
- h.processor.Enqueue(item)
+ require.NoError(t, h.processor.Submit(item), "precondition: Submit should not fail")
h.Go()
// --- ASSERT ---
@@ -393,7 +351,7 @@ func TestShardProcessor(t *testing.T) {
// --- ACT ---
h.Start()
defer h.Stop()
- h.processor.Enqueue(item)
+ require.NoError(t, h.processor.Submit(item), "precondition: Submit should not fail")
h.Go()
// --- ASSERT ---
@@ -413,94 +371,12 @@ func TestShardProcessor(t *testing.T) {
// --- ACT ---
h.Start()
h.Go()
- // Stop the processor, then immediately try to enqueue.
- h.Stop()
- h.processor.Enqueue(item)
-
- // --- ASSERT ---
- outcome, err := h.waitForFinalization(item)
- assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, "The outcome should be RejectedOther")
- require.Error(t, err, "An eviction on shutdown should produce an error")
- assert.ErrorIs(t, err, types.ErrFlowControllerShutdown, "The error should be of type ErrFlowControllerShutdown")
- })
-
- t.Run("should evict item on TTL expiry via background cleanup", func(t *testing.T) {
- t.Parallel()
- // --- ARRANGE ---
- h := newTestHarness(t, testCleanupTick)
- item := h.newTestItem("req-expired-evict", testFlow, testShortTTL)
- h.addQueue(testFlow)
-
- // --- ACT ---
- h.Start()
- defer h.Stop()
- h.processor.Enqueue(item)
- h.Go()
-
- // Let time pass for the item to expire and for the background cleanup to run.
- h.mockClock.Advance(testShortTTL * 2)
- time.Sleep(testCleanupTick * 3) // Allow the cleanup goroutine time to run.
+ h.Stop() // Stop the processor, then immediately try to submit.
+ require.ErrorIs(t, h.processor.Submit(item), types.ErrFlowControllerNotRunning,
+ "Submit should return ErrFlowControllerNotRunning on shutdown")
// --- ASSERT ---
- outcome, err := h.waitForFinalization(item)
- assert.Equal(t, types.QueueOutcomeEvictedTTL, outcome, "The final outcome should be EvictedTTL")
- require.Error(t, err, "A TTL eviction should produce an error")
- assert.ErrorIs(t, err, types.ErrTTLExpired, "The error should be of type ErrTTLExpired")
- })
-
- t.Run("should evict item at moment of dispatch if TTL has expired", func(t *testing.T) {
- t.Parallel()
- // --- ARRANGE ---
- h := newTestHarness(t, 1*time.Hour) // Disable background cleanup to isolate dispatch logic.
- item := h.newTestItem("req-expired-dispatch-evict", testFlow, testShortTTL)
- mockQueue := h.addQueue(testFlow)
- require.NoError(t, mockQueue.Add(item), "Adding item to mock queue should not fail")
-
- // Have the policy select the item, but then advance time so it's expired by the time dispatchItem actually runs.
- h.interFlowPolicySelectQueue = func(band framework.PriorityBandAccessor) (framework.FlowQueueAccessor, error) {
- h.mockClock.Advance(testShortTTL * 2)
- return mockQueue.FlowQueueAccessor(), nil
- }
-
- // --- ACT ---
- h.Start()
- defer h.Stop()
- h.Go()
-
- // The run loop will pick up the item and attempt dispatch, which will fail internally.
- // We need a small sleep to allow the non-blocking run loop to process.
- time.Sleep(50 * time.Millisecond)
-
- // --- ASSERT ---
- outcome, err := h.waitForFinalization(item)
- assert.Equal(t, types.QueueOutcomeEvictedTTL, outcome, "The final outcome should be EvictedTTL")
- require.Error(t, err, "An eviction at dispatch time should produce an error")
- assert.ErrorIs(t, err, types.ErrTTLExpired, "The error should be of type ErrTTLExpired")
- })
-
- t.Run("should evict item on context cancellation", func(t *testing.T) {
- t.Parallel()
- // --- ARRANGE ---
- h := newTestHarness(t, testCleanupTick)
- ctx, cancel := context.WithCancel(context.Background())
- req := typesmocks.NewMockFlowControlRequest(100, "req-ctx-cancel", testFlow, ctx)
- item := NewItem(req, testTTL, h.mockClock.Now())
- h.addQueue(testFlow)
-
- // --- ACT ---
- h.Start()
- defer h.Stop()
- h.processor.Enqueue(item)
- h.Go()
- cancel() // Cancel the context *after* the item is enqueued.
- time.Sleep(testCleanupTick * 3) // Allow the cleanup goroutine time to run.
-
- // --- ASSERT ---
- outcome, err := h.waitForFinalization(item)
- assert.Equal(t, types.QueueOutcomeEvictedContextCancelled, outcome,
- "The outcome should be EvictedContextCancelled")
- require.Error(t, err, "A context cancellation eviction should produce an error")
- assert.ErrorIs(t, err, types.ErrContextCancelled, "The error should be of type ErrContextCancelled")
+ assert.Nil(t, item.FinalState(), "Item should not be finalized by the processor")
})
t.Run("should evict a queued item on shutdown", func(t *testing.T) {
@@ -525,7 +401,8 @@ func TestShardProcessor(t *testing.T) {
outcome, err := h.waitForFinalization(item)
assert.Equal(t, types.QueueOutcomeEvictedOther, outcome, "The outcome should be EvictedOther")
require.Error(t, err, "An eviction on shutdown should produce an error")
- assert.ErrorIs(t, err, types.ErrFlowControllerShutdown, "The error should be of type ErrFlowControllerShutdown")
+ assert.ErrorIs(t, err, types.ErrFlowControllerNotRunning,
+ "The error should be of type ErrFlowControllerNotRunning")
})
t.Run("should handle concurrent enqueues and dispatch all items", func(t *testing.T) {
@@ -534,8 +411,8 @@ func TestShardProcessor(t *testing.T) {
h := newTestHarness(t, testCleanupTick)
const numConcurrentItems = 20
q := h.addQueue(testFlow)
- itemsToTest := make([]*flowItem, 0, numConcurrentItems)
- for i := 0; i < numConcurrentItems; i++ {
+ itemsToTest := make([]*FlowItem, 0, numConcurrentItems)
+ for i := range numConcurrentItems {
item := h.newTestItem(fmt.Sprintf("req-concurrent-%d", i), testFlow, testTTL)
itemsToTest = append(itemsToTest, item)
}
@@ -546,9 +423,9 @@ func TestShardProcessor(t *testing.T) {
var wg sync.WaitGroup
for _, item := range itemsToTest {
wg.Add(1)
- go func(fi *flowItem) {
+ go func(fi *FlowItem) {
defer wg.Done()
- h.processor.Enqueue(fi)
+ assert.NoError(t, h.processor.Submit(fi), "Submit should not fail")
}(item)
}
h.Go()
@@ -576,16 +453,26 @@ func TestShardProcessor(t *testing.T) {
// Use channels to pause the dispatch cycle right before it would remove the item.
policyCanProceed := make(chan struct{})
itemIsBeingDispatched := make(chan struct{})
+ var signalOnce sync.Once
+ var removedItem types.QueueItemAccessor
require.NoError(t, q.Add(item)) // Add the item directly to the queue.
// Override the queue's `RemoveFunc` to pause the dispatch goroutine at a critical moment.
q.RemoveFunc = func(h types.QueueItemHandle) (types.QueueItemAccessor, error) {
- close(itemIsBeingDispatched) // 1. Signal that dispatch is happening.
- <-policyCanProceed // 2. Wait for the test to tell us to continue.
- // 4. After we unblock, the item will have already been finalized by the cleanup logic, so we simulate the
- // real-world outcome of a failed remove.
- return nil, fmt.Errorf("item with handle %v not found", h)
+ var err error
+ signalOnce.Do(func() {
+ removedItem = item
+ close(itemIsBeingDispatched) // 1. Signal that dispatch is happening.
+ <-policyCanProceed // 2. Wait for the test to tell us to continue.
+ // 4. After we unblock, the item will have already been finalized by the cleanup logic.
+ // We simulate the item no longer being found.
+ err = fmt.Errorf("item with handle %v not found", h)
+ })
+ if removedItem == item {
+ return item, nil // Return the item on the first call
+ }
+ return nil, err // Return error on subsequent calls
}
// --- ACT ---
@@ -594,20 +481,23 @@ func TestShardProcessor(t *testing.T) {
h.Go()
// Wait for the dispatch cycle to select our item and pause inside our mock `RemoveFunc`.
- <-itemIsBeingDispatched
+ select {
+ case <-itemIsBeingDispatched:
+ case <-time.After(testWaitTimeout):
+ t.Fatal("Timed out waiting for item to be dispatched")
+ }
// 3. The dispatch goroutine is now paused. We can now safely win the "race" by running cleanup logic.
- h.mockClock.Advance(testShortTTL * 2)
- h.processor.cleanupExpired(h.mockClock.Now()) // This will remove and finalize the item.
+ h.clock.Step(testShortTTL * 2)
+ item.Finalize(types.ErrTTLExpired) // This finalizes the item with outcome EvictedTTL.
- // 5. Un-pause the dispatch goroutine. It will now fail to remove the item and the `dispatchCycle` will
- // correctly conclude without finalizing the item a second time.
+ // 5. Un-pause the dispatch goroutine.
close(policyCanProceed)
// --- ASSERT ---
- // The item's final state should be from the cleanup logic (EvictedTTL), not the dispatch logic.
+ // The item's final state should be from the Finalize call above.
outcome, err := h.waitForFinalization(item)
- assert.Equal(t, types.QueueOutcomeEvictedTTL, outcome, "The outcome should be EvictedTTL from the cleanup routine")
+ assert.Equal(t, types.QueueOutcomeEvictedTTL, outcome, "The outcome should be EvictedTTL from the Finalize call")
require.Error(t, err, "A TTL eviction should produce an error")
assert.ErrorIs(t, err, types.ErrTTLExpired, "The error should be of type ErrTTLExpired")
})
@@ -647,7 +537,7 @@ func TestShardProcessor(t *testing.T) {
h.Start()
defer h.Stop()
h.Go()
- h.processor.Enqueue(nil)
+ require.NoError(t, h.processor.Submit(nil), "Submit should not fail")
// --- ASSERT ---
// Allow a moment for the processor to potentially process the nil item.
@@ -666,32 +556,32 @@ func TestShardProcessor(t *testing.T) {
testCases := []struct {
name string
setupHarness func(h *testHarness)
- item *flowItem
- assert func(t *testing.T, h *testHarness, item *flowItem)
+ item *FlowItem
+ assert func(t *testing.T, h *testHarness, item *FlowItem)
}{
{
name: "should reject item on registry queue lookup failure",
setupHarness: func(h *testHarness) {
h.ManagedQueueFunc = func(types.FlowKey) (contracts.ManagedQueue, error) { return nil, testErr }
},
- assert: func(t *testing.T, h *testHarness, item *flowItem) {
- outcome, err := item.FinalState()
- assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, "Outcome should be RejectedOther")
- require.Error(t, err, "An error should be returned")
- assert.ErrorIs(t, err, testErr, "The underlying error should be preserved")
+ assert: func(t *testing.T, h *testHarness, item *FlowItem) {
+ assert.Equal(t, types.QueueOutcomeRejectedOther, item.FinalState().Outcome,
+ "Outcome should be RejectedOther")
+ require.Error(t, item.FinalState().Err, "An error should be returned")
+ assert.ErrorIs(t, item.FinalState().Err, testErr, "The underlying error should be preserved")
},
},
{
name: "should reject item on registry priority band lookup failure",
setupHarness: func(h *testHarness) {
h.addQueue(testFlow)
- h.PriorityBandAccessorFunc = func(uint) (framework.PriorityBandAccessor, error) { return nil, testErr }
+ h.PriorityBandAccessorFunc = func(int) (framework.PriorityBandAccessor, error) { return nil, testErr }
},
- assert: func(t *testing.T, h *testHarness, item *flowItem) {
- outcome, err := item.FinalState()
- assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, "Outcome should be RejectedOther")
- require.Error(t, err, "An error should be returned")
- assert.ErrorIs(t, err, testErr, "The underlying error should be preserved")
+ assert: func(t *testing.T, h *testHarness, item *FlowItem) {
+ assert.Equal(t, types.QueueOutcomeRejectedOther, item.FinalState().Outcome,
+ "Outcome should be RejectedOther")
+ require.Error(t, item.FinalState().Err, "An error should be returned")
+ assert.ErrorIs(t, item.FinalState().Err, testErr, "The underlying error should be preserved")
},
},
{
@@ -700,11 +590,11 @@ func TestShardProcessor(t *testing.T) {
mockQueue := h.addQueue(testFlow)
mockQueue.AddFunc = func(types.QueueItemAccessor) error { return testErr }
},
- assert: func(t *testing.T, h *testHarness, item *flowItem) {
- outcome, err := item.FinalState()
- assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, "Outcome should be RejectedOther")
- require.Error(t, err, "An error should be returned")
- assert.ErrorIs(t, err, testErr, "The underlying error should be preserved")
+ assert: func(t *testing.T, h *testHarness, item *FlowItem) {
+ assert.Equal(t, types.QueueOutcomeRejectedOther, item.FinalState().Outcome,
+ "Outcome should be RejectedOther")
+ require.Error(t, item.FinalState().Err, "An error should be returned")
+ assert.ErrorIs(t, item.FinalState().Err, testErr, "The underlying error should be preserved")
},
},
{
@@ -721,17 +611,16 @@ func TestShardProcessor(t *testing.T) {
assert.Equal(t, 0, addCallCount, "Queue.Add should not have been called for a finalized item")
})
},
- item: func() *flowItem {
+ item: func() *FlowItem {
// Create a pre-finalized item.
item := newTestHarness(t, 0).newTestItem("req-finalized", testFlow, testTTL)
- item.finalize(types.QueueOutcomeDispatched, nil)
+ item.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil)
return item
}(),
- assert: func(t *testing.T, h *testHarness, item *flowItem) {
+ assert: func(t *testing.T, h *testHarness, item *FlowItem) {
// The item was already finalized, so its state should not change.
- outcome, err := item.FinalState()
- assert.Equal(t, types.QueueOutcomeDispatched, outcome, "Outcome should remain unchanged")
- assert.NoError(t, err, "Error should remain unchanged")
+ assert.Equal(t, types.QueueOutcomeDispatched, item.FinalState().Outcome, "Outcome should remain unchanged")
+ assert.NoError(t, item.FinalState().Err, "Error should remain unchanged")
},
},
}
@@ -776,7 +665,7 @@ func TestShardProcessor(t *testing.T) {
itemByteSize: 1,
stats: contracts.ShardStats{
TotalCapacityBytes: 200, TotalByteSize: 100,
- PerPriorityBandStats: map[uint]contracts.PriorityBandStats{
+ PerPriorityBandStats: map[int]contracts.PriorityBandStats{
testFlow.Priority: {ByteSize: 50, CapacityBytes: 50},
},
},
@@ -787,7 +676,7 @@ func TestShardProcessor(t *testing.T) {
itemByteSize: 1,
stats: contracts.ShardStats{
TotalCapacityBytes: 200, TotalByteSize: 100,
- PerPriorityBandStats: map[uint]contracts.PriorityBandStats{}, // Missing stats for priority 10
+ PerPriorityBandStats: map[int]contracts.PriorityBandStats{}, // Missing stats for priority 10
},
expectHasCap: false,
},
@@ -796,7 +685,7 @@ func TestShardProcessor(t *testing.T) {
itemByteSize: 10,
stats: contracts.ShardStats{
TotalCapacityBytes: 200, TotalByteSize: 100,
- PerPriorityBandStats: map[uint]contracts.PriorityBandStats{
+ PerPriorityBandStats: map[int]contracts.PriorityBandStats{
testFlow.Priority: {ByteSize: 50, CapacityBytes: 100},
},
},
@@ -836,17 +725,19 @@ func TestShardProcessor(t *testing.T) {
expectDidDispatch: false,
},
{
- name: "should stop dispatching when filter signals pause",
+ name: "should block dispatch on HoL saturation",
setupHarness: func(h *testHarness) {
- // Add an item that *could* be dispatched to prove the pause is effective.
- q := h.addQueue(testFlow)
- require.NoError(t, q.Add(h.newTestItem("item", testFlow, testTTL)))
- h.processor.dispatchFilter = func(
- _ context.Context,
- _ framework.PriorityBandAccessor,
- _ logr.Logger,
- ) (framework.PriorityBandAccessor, bool) {
- return nil, true // Signal pause.
+ // Add a high-priority item that will be selected but is saturated.
+ qHigh := h.addQueue(testFlow) // priority 10
+ require.NoError(t, qHigh.Add(h.newTestItem("item-high", testFlow, testTTL)))
+
+ // Add a low-priority, viable item.
+ keyLow := types.FlowKey{ID: "flow-low", Priority: 5}
+ qLow := h.addQueue(keyLow)
+ require.NoError(t, qLow.Add(h.newTestItem("item-low", keyLow, testTTL)))
+
+ h.saturationDetector.IsSaturatedFunc = func(_ context.Context, _ []metrics.PodMetrics) bool {
+ return true
}
},
expectDidDispatch: false,
@@ -854,7 +745,7 @@ func TestShardProcessor(t *testing.T) {
{
name: "should skip band on priority band accessor error",
setupHarness: func(h *testHarness) {
- h.PriorityBandAccessorFunc = func(uint) (framework.PriorityBandAccessor, error) {
+ h.PriorityBandAccessorFunc = func(int) (framework.PriorityBandAccessor, error) {
return nil, registryErr
}
},
@@ -955,62 +846,18 @@ func TestShardProcessor(t *testing.T) {
}
})
- t.Run("should use filtered view of queues when filter is active", func(t *testing.T) {
- t.Parallel()
- // --- ARRANGE ---
- h := newTestHarness(t, testCleanupTick)
- flowB := types.FlowKey{ID: "flow-b", Priority: testFlow.Priority}
- h.addQueue(testFlow)
- qB := h.addQueue(flowB)
- itemB := h.newTestItem("item-b", flowB, testTTL)
- require.NoError(t, qB.Add(itemB))
-
- // This filter only allows flow-b.
- h.processor.dispatchFilter = func(
- _ context.Context,
- originalBand framework.PriorityBandAccessor,
- _ logr.Logger,
- ) (framework.PriorityBandAccessor, bool) {
- return newSubsetPriorityBandAccessor(originalBand, []types.FlowKey{flowB}), false
- }
-
- // This policy will be given the filtered view, so it should only see flow-b.
- h.interFlowPolicySelectQueue = func(band framework.PriorityBandAccessor) (framework.FlowQueueAccessor, error) {
- var flowIDs []string
- band.IterateQueues(func(fqa framework.FlowQueueAccessor) bool {
- flowIDs = append(flowIDs, fqa.FlowKey().ID)
- return true
- })
- // This is the core assertion of the test.
- require.ElementsMatch(t, []string{flowB.ID}, flowIDs, "Policy should only see the filtered flow")
-
- // Select flow-b to prove the chain works.
- q, _ := h.managedQueue(flowB)
- return q.FlowQueueAccessor(), nil
- }
-
- // --- ACT ---
- dispatched := h.processor.dispatchCycle(context.Background())
-
- // --- ASSERT ---
- assert.True(t, dispatched, "An item should have been dispatched from the filtered flow")
- outcome, err := itemB.FinalState()
- assert.Equal(t, types.QueueOutcomeDispatched, outcome, "The dispatched item's outcome should be correct")
- assert.NoError(t, err, "The dispatched item should not have an error")
- })
-
t.Run("should guarantee strict priority by starving lower priority items", func(t *testing.T) {
t.Parallel()
// --- ARRANGE ---
h := newTestHarness(t, testCleanupTick)
- keyHigh := types.FlowKey{ID: "flow-high", Priority: 10}
- keyLow := types.FlowKey{ID: "flow-low", Priority: 20}
+ keyHigh := types.FlowKey{ID: "flow-high", Priority: 20}
+ keyLow := types.FlowKey{ID: "flow-low", Priority: 10}
qHigh := h.addQueue(keyHigh)
qLow := h.addQueue(keyLow)
const numItems = 3
- highPrioItems := make([]*flowItem, numItems)
- lowPrioItems := make([]*flowItem, numItems)
+ highPrioItems := make([]*FlowItem, numItems)
+ lowPrioItems := make([]*FlowItem, numItems)
for i := range numItems {
// Add high priority items.
itemH := h.newTestItem(fmt.Sprintf("req-high-%d", i), keyHigh, testTTL)
@@ -1032,9 +879,9 @@ func TestShardProcessor(t *testing.T) {
// Verify all high-priority items are gone and low-priority items remain.
for _, item := range highPrioItems {
- outcome, err := item.FinalState()
- assert.Equal(t, types.QueueOutcomeDispatched, outcome, "High-priority item should be dispatched")
- assert.NoError(t, err, "Dispatched high-priority item should not have an error")
+ assert.Equal(t, types.QueueOutcomeDispatched, item.FinalState().Outcome,
+ "High-priority item should be dispatched")
+ assert.NoError(t, item.FinalState().Err, "Dispatched high-priority item should not have an error")
}
assert.Equal(t, numItems, qLow.Len(), "Low-priority queue should still be full")
@@ -1066,19 +913,6 @@ func TestShardProcessor(t *testing.T) {
},
expectedErr: registryErr,
},
- {
- name: "on queue remove failure",
- setupMocks: func(h *testHarness) {
- h.ManagedQueueFunc = func(types.FlowKey) (contracts.ManagedQueue, error) {
- return &mocks.MockManagedQueue{
- RemoveFunc: func(types.QueueItemHandle) (types.QueueItemAccessor, error) {
- return nil, registryErr
- },
- }, nil
- }
- },
- expectedErr: registryErr,
- },
}
for _, tc := range testCases {
@@ -1087,18 +921,19 @@ func TestShardProcessor(t *testing.T) {
h := newTestHarness(t, testCleanupTick)
tc.setupMocks(h)
item := h.newTestItem("req-dispatch-fail", testFlow, testTTL)
- err := h.processor.dispatchItem(item, h.logger)
+ err := h.processor.dispatchItem(item)
require.Error(t, err, "dispatchItem should return an error")
assert.ErrorIs(t, err, tc.expectedErr, "The underlying registry error should be preserved")
})
}
})
- t.Run("should evict item that expires at moment of dispatch", func(t *testing.T) {
+ t.Run("should not dispatch already finalized item", func(t *testing.T) {
t.Parallel()
// --- ARRANGE ---
h := newTestHarness(t, testCleanupTick)
- item := h.newTestItem("req-expired-dispatch", testFlow, testShortTTL)
+ item := h.newTestItem("req-already-finalized", testFlow, testTTL)
+ item.FinalizeWithOutcome(types.QueueOutcomeRejectedOther, errors.New("already done"))
h.ManagedQueueFunc = func(types.FlowKey) (contracts.ManagedQueue, error) {
return &mocks.MockManagedQueue{
@@ -1109,69 +944,61 @@ func TestShardProcessor(t *testing.T) {
}
// --- ACT ---
- h.mockClock.Advance(testShortTTL * 2) // Make the item expire.
- err := h.processor.dispatchItem(item, h.logger)
+ err := h.processor.dispatchItem(item)
// --- ASSERT ---
- // First, check the error returned by `dispatchItem`.
- require.Error(t, err, "dispatchItem should return an error for an expired item")
- assert.ErrorIs(t, err, types.ErrTTLExpired, "The error should be of type ErrTTLExpired")
-
- // Second, check the final state of the item itself.
- outcome, finalErr := item.FinalState()
- assert.Equal(t, types.QueueOutcomeEvictedTTL, outcome, "The item's final outcome should be EvictedTTL")
- require.Error(t, finalErr, "The item's final state should contain an error")
- assert.ErrorIs(t, finalErr, types.ErrTTLExpired, "The item's final error should be of type ErrTTLExpired")
+ require.NoError(t, err, "dispatchItem should return no error for an already finalized item")
+
+ // Check the final state of the item itself - it should not have changed.
+ finalState := item.FinalState()
+ require.NotNil(t, finalState, "Item must be finalized")
+ assert.Equal(t, types.QueueOutcomeRejectedOther, finalState.Outcome,
+ "The item's final outcome should be RejectedOther")
+ assert.ErrorContains(t, finalState.Err, "already done",
+ "The error should be the one from the first Finalize call")
})
+ })
- t.Run("should panic if queue returns item of wrong type", func(t *testing.T) {
+ t.Run("cleanup and utility methods", func(t *testing.T) {
+ t.Parallel()
+
+ t.Run("should sweep externally finalized items", func(t *testing.T) {
t.Parallel()
// --- ARRANGE ---
h := newTestHarness(t, testCleanupTick)
- badItem := &typesmocks.MockQueueItemAccessor{
- OriginalRequestV: typesmocks.NewMockFlowControlRequest(0, "bad-item", testFlow, context.Background()),
- }
+ item := h.newTestItem("req-external-finalized", testFlow, testTTL)
+ q := h.addQueue(testFlow)
+ require.NoError(t, q.Add(item), "Failed to add item to queue")
- h.ManagedQueueFunc = func(types.FlowKey) (contracts.ManagedQueue, error) {
- return &mocks.MockManagedQueue{
- RemoveFunc: func(types.QueueItemHandle) (types.QueueItemAccessor, error) {
- return badItem, nil
- },
- }, nil
- }
+ // Externally finalize the item
+ item.Finalize(context.Canceled)
+ require.NotNil(t, item.FinalState(), "Item should be finalized")
- itemToDispatch := h.newTestItem("req-dispatch-panic", testFlow, testTTL)
- expectedPanicMsg := fmt.Sprintf("%s: internal error: item %q of type %T is not a *flowItem",
- errIntraFlow, "bad-item", badItem)
+ // --- ACT ---
+ h.processor.sweepFinalizedItems()
- // --- ACT & ASSERT ---
- assert.PanicsWithError(t, expectedPanicMsg, func() {
- _ = h.processor.dispatchItem(itemToDispatch, h.logger)
- }, "A type mismatch from a queue should cause a panic")
+ // --- ASSERT ---
+ assert.Equal(t, 0, q.Len(), "Queue should be empty after sweep")
+ finalState := item.FinalState()
+ assert.Equal(t, types.QueueOutcomeEvictedContextCancelled, finalState.Outcome,
+ "Outcome should be EvictedContextCancelled")
+ assert.ErrorIs(t, finalState.Err, types.ErrContextCancelled, "Error should be ErrContextCancelled")
})
- })
-
- t.Run("cleanup and utility methods", func(t *testing.T) {
- t.Parallel()
- t.Run("should remove and finalize expired items", func(t *testing.T) {
+ t.Run("should not sweep items not finalized", func(t *testing.T) {
t.Parallel()
// --- ARRANGE ---
h := newTestHarness(t, testCleanupTick)
- // Create an item that is already expired relative to the cleanup time.
- item := h.newTestItem("req-expired", testFlow, 1*time.Millisecond)
+ item := h.newTestItem("req-not-finalized", testFlow, testTTL)
q := h.addQueue(testFlow)
- require.NoError(t, q.Add(item))
- cleanupTime := h.mockClock.Now().Add(10 * time.Millisecond)
+ require.NoError(t, q.Add(item), "Failed to add item to queue")
// --- ACT ---
- h.processor.cleanupExpired(cleanupTime)
+ h.processor.sweepFinalizedItems()
// --- ASSERT ---
- outcome, err := item.FinalState()
- assert.Equal(t, types.QueueOutcomeEvictedTTL, outcome, "Item outcome should be EvictedTTL")
- require.Error(t, err, "Item should have an error")
- assert.ErrorIs(t, err, types.ErrTTLExpired, "Item error should be ErrTTLExpired")
+ assert.Equal(t, 1, q.Len(), "Queue should still contain the item")
+ assert.Nil(t, item.FinalState(), "Item should not be finalized")
})
t.Run("should evict all items on shutdown", func(t *testing.T) {
@@ -1186,18 +1013,19 @@ func TestShardProcessor(t *testing.T) {
h.processor.evictAll()
// --- ASSERT ---
- outcome, err := item.FinalState()
- assert.Equal(t, types.QueueOutcomeEvictedOther, outcome, "Item outcome should be EvictedOther")
- require.Error(t, err, "Item should have an error")
- assert.ErrorIs(t, err, types.ErrFlowControllerShutdown, "Item error should be ErrFlowControllerShutdown")
+ assert.Equal(t, types.QueueOutcomeEvictedOther, item.FinalState().Outcome,
+ "Item outcome should be EvictedOther")
+ require.Error(t, item.FinalState().Err, "Item should have an error")
+ assert.ErrorIs(t, item.FinalState().Err, types.ErrFlowControllerNotRunning,
+ "Item error should be ErrFlowControllerNotRunning")
})
t.Run("should handle registry errors gracefully during concurrent processing", func(t *testing.T) {
t.Parallel()
// --- ARRANGE ---
h := newTestHarness(t, testCleanupTick)
- h.AllOrderedPriorityLevelsFunc = func() []uint { return []uint{testFlow.Priority} }
- h.PriorityBandAccessorFunc = func(p uint) (framework.PriorityBandAccessor, error) {
+ h.AllOrderedPriorityLevelsFunc = func() []int { return []int{testFlow.Priority} }
+ h.PriorityBandAccessorFunc = func(p int) (framework.PriorityBandAccessor, error) {
return nil, errors.New("registry error")
}
@@ -1208,25 +1036,6 @@ func TestShardProcessor(t *testing.T) {
}, "processAllQueuesConcurrently should not panic on registry errors")
})
- t.Run("should handle items of an unexpected type gracefully during finalization", func(t *testing.T) {
- t.Parallel()
- // --- ARRANGE ---
- h := newTestHarness(t, testCleanupTick)
- item := &typesmocks.MockQueueItemAccessor{
- OriginalRequestV: typesmocks.NewMockFlowControlRequest(0, "bad-item", testFlow, context.Background()),
- }
- items := []types.QueueItemAccessor{item}
-
- // --- ACT & ASSERT ---
- // The test passes if this call completes without panicking.
- assert.NotPanics(t, func() {
- getOutcome := func(types.QueueItemAccessor) (types.QueueOutcome, error) {
- return types.QueueOutcomeEvictedOther, nil
- }
- h.processor.finalizeItems(items, h.logger, getOutcome)
- }, "finalizeItems should not panic on unexpected item types")
- })
-
t.Run("should process all queues with a worker pool", func(t *testing.T) {
t.Parallel()
// --- ARRANGE ---
@@ -1257,122 +1066,120 @@ func TestShardProcessor(t *testing.T) {
})
})
})
-}
-func TestCheckItemExpiry(t *testing.T) {
- t.Parallel()
+ t.Run("Public API", func(t *testing.T) {
+ t.Parallel()
- // --- ARRANGE ---
- now := time.Now()
- ctxCancelled, cancel := context.WithCancel(context.Background())
- cancel() // Cancel the context immediately.
-
- testCases := []struct {
- name string
- item types.QueueItemAccessor
- now time.Time
- expectExpired bool
- expectOutcome types.QueueOutcome
- expectErr error
- }{
- {
- name: "should not be expired if TTL is not reached and context is active",
- item: NewItem(
- typesmocks.NewMockFlowControlRequest(100, "req-not-expired", testFlow, context.Background()),
- testTTL,
- now),
- now: now.Add(30 * time.Second),
- expectExpired: false,
- expectOutcome: types.QueueOutcomeNotYetFinalized,
- expectErr: nil,
- },
- {
- name: "should not be expired if TTL is disabled (0)",
- item: NewItem(
- typesmocks.NewMockFlowControlRequest(100, "req-not-expired-no-ttl", testFlow, context.Background()),
- 0,
- now),
- now: now.Add(30 * time.Second),
- expectExpired: false,
- expectOutcome: types.QueueOutcomeNotYetFinalized,
- expectErr: nil,
- },
- {
- name: "should be expired if TTL is exceeded",
- item: NewItem(
- typesmocks.NewMockFlowControlRequest(100, "req-ttl-expired", testFlow, context.Background()),
- time.Second,
- now),
- now: now.Add(2 * time.Second),
- expectExpired: true,
- expectOutcome: types.QueueOutcomeEvictedTTL,
- expectErr: types.ErrTTLExpired,
- },
- {
- name: "should be expired if context is cancelled",
- item: NewItem(
- typesmocks.NewMockFlowControlRequest(100, "req-ctx-cancelled", testFlow, ctxCancelled),
- testTTL,
- now),
- now: now,
- expectExpired: true,
- expectOutcome: types.QueueOutcomeEvictedContextCancelled,
- expectErr: types.ErrContextCancelled,
- },
- {
- name: "should be expired if already finalized",
- item: func() types.QueueItemAccessor {
- i := NewItem(
- typesmocks.NewMockFlowControlRequest(100, "req-finalized", testFlow, context.Background()),
- testTTL,
- now)
- i.finalize(types.QueueOutcomeDispatched, nil)
- return i
- }(),
- now: now,
- expectExpired: true,
- expectOutcome: types.QueueOutcomeDispatched,
- expectErr: nil,
- },
- }
+ t.Run("Submit", func(t *testing.T) {
+ t.Parallel()
+
+ t.Run("should return ErrProcessorBusy when channel is full", func(t *testing.T) {
+ t.Parallel()
+ h := newTestHarness(t, testCleanupTick)
+ h.processor.enqueueChan = make(chan *FlowItem, 1)
+ h.processor.enqueueChan <- h.newTestItem("item-filler", testFlow, testTTL) // Fill the channel to capacity.
+
+ // The next submit should be non-blocking and fail immediately.
+ err := h.processor.Submit(h.newTestItem("item-to-reject", testFlow, testTTL))
+ require.Error(t, err, "Submit must return an error when the channel is full")
+ assert.ErrorIs(t, err, ErrProcessorBusy, "The returned error must be ErrProcessorBusy")
+ })
+
+ t.Run("should return ErrFlowControllerNotRunning if lifecycleCtx is cancelled", func(t *testing.T) {
+ t.Parallel()
+ h := newTestHarness(t, testCleanupTick)
+ h.Start()
+ h.Go() // Ensure the Run loop has started
+ h.cancel() // Cancel the lifecycle context
+ h.Stop() // Wait for the processor to fully stop
+
+ item := h.newTestItem("item-ctx-cancel", testFlow, testTTL)
+ err := h.processor.Submit(item)
+ require.ErrorIs(t, err, types.ErrFlowControllerNotRunning,
+ "Submit must return ErrFlowControllerNotRunning when lifecycleCtx is cancelled")
+ assert.Nil(t, item.FinalState(), "Item should not be finalized by Submit")
+
+ err = h.processor.SubmitOrBlock(context.Background(), item)
+ require.ErrorIs(t, err, types.ErrFlowControllerNotRunning,
+ "SubmitOrBlock must return ErrFlowControllerNotRunning when lifecycleCtx is cancelled")
+ assert.Nil(t, item.FinalState(), "Item should not be finalized by SubmitOrBlock")
+ })
+ })
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
+ t.Run("SubmitOrBlock", func(t *testing.T) {
t.Parallel()
- // --- ACT ---
- isExpired, outcome, err := checkItemExpiry(tc.item, tc.now)
- // --- ASSERT ---
- assert.Equal(t, tc.expectExpired, isExpired, "Expired status should match expected value")
- assert.Equal(t, tc.expectOutcome, outcome, "Outcome should match expected value")
-
- if tc.expectErr != nil {
- require.Error(t, err, "An error was expected")
- // Use ErrorIs for sentinel errors, ErrorContains for general messages.
- if errors.Is(tc.expectErr, types.ErrTTLExpired) || errors.Is(tc.expectErr, types.ErrContextCancelled) {
- assert.ErrorIs(t, err, tc.expectErr, "The specific error type should be correct")
- } else {
- assert.ErrorContains(t, err, tc.expectErr.Error(), "The error message should contain the expected text")
+ t.Run("should block when channel is full and succeed when space becomes available", func(t *testing.T) {
+ t.Parallel()
+ h := newTestHarness(t, testCleanupTick)
+ h.processor.enqueueChan = make(chan *FlowItem, 1)
+ h.processor.enqueueChan <- h.newTestItem("item-filler", testFlow, testTTL) // Fill the channel to capacity.
+
+ itemToSubmit := h.newTestItem("item-to-block", testFlow, testTTL)
+ submitErr := make(chan error, 1)
+
+ // Run `SubmitOrBlock` in a separate goroutine, as it will block.
+ go func() {
+ submitErr <- h.processor.SubmitOrBlock(context.Background(), itemToSubmit)
+ }()
+
+ // Prove that the call is blocking by ensuring it hasn't returned an error yet.
+ time.Sleep(20 * time.Millisecond)
+ require.Len(t, submitErr, 0, "SubmitOrBlock should be blocking and not have returned yet")
+ <-h.processor.enqueueChan // Make space in the channel. This should unblock the goroutine.
+
+ select {
+ case err := <-submitErr:
+ require.NoError(t, err, "SubmitOrBlock should succeed and return no error after being unblocked")
+ case <-time.After(testWaitTimeout):
+ t.Fatal("SubmitOrBlock did not return after space was made in the channel")
}
- } else {
- assert.NoError(t, err, "No error was expected")
- }
- })
- }
+ })
- t.Run("should panic on item of an unexpected type", func(t *testing.T) {
- t.Parallel()
- // --- ARRANGE ---
- badItem := &typesmocks.MockQueueItemAccessor{
- OriginalRequestV: typesmocks.NewMockFlowControlRequest(0, "item-bad-type", testFlow, context.Background()),
- }
+ t.Run("should unblock and return context error on cancellation", func(t *testing.T) {
+ t.Parallel()
+ h := newTestHarness(t, testCleanupTick)
+ h.processor.enqueueChan = make(chan *FlowItem) // Use an unbuffered channel to guarantee the first send blocks.
+ itemToSubmit := h.newTestItem("item-to-cancel", testFlow, testTTL)
+ submitErr := make(chan error, 1)
+ ctx, cancel := context.WithCancel(context.Background())
+
+ // Run `SubmitOrBlock` in a separate goroutine, as it will block.
+ go func() {
+ submitErr <- h.processor.SubmitOrBlock(ctx, itemToSubmit)
+ }()
+
+ // Prove that the call is blocking.
+ time.Sleep(20 * time.Millisecond)
+ require.Len(t, submitErr, 0, "SubmitOrBlock should be blocking and not have returned yet")
+ cancel() // Cancel the context. This should unblock the goroutine.
+
+ select {
+ case err := <-submitErr:
+ require.Error(t, err, "SubmitOrBlock should return an error after context cancellation")
+ assert.ErrorIs(t, err, context.Canceled, "The returned error must be context.Canceled")
+ case <-time.After(testWaitTimeout):
+ t.Fatal("SubmitOrBlock did not return after context was cancelled")
+ }
+ })
+
+ t.Run("should reject immediately if shutting down", func(t *testing.T) {
+ t.Parallel()
+ h := newTestHarness(t, testCleanupTick)
+ item := h.newTestItem("req-shutdown-reject", testFlow, testTTL)
+ h.addQueue(testFlow)
+
+ h.Start()
+ h.Go()
+ h.Stop() // Stop the processor, then immediately try to enqueue.
+ err := h.processor.SubmitOrBlock(context.Background(), item)
- expectedPanicMsg := fmt.Sprintf("internal error: item %q of type %T is not a *flowItem",
- badItem.OriginalRequestV.ID(), badItem)
+ require.Error(t, err, "SubmitOrBlock should return an error when shutting down")
+ assert.ErrorIs(t, err, types.ErrFlowControllerNotRunning, "The error should be ErrFlowControllerNotRunning")
- // --- ACT & ASSERT ---
- assert.PanicsWithError(t, expectedPanicMsg, func() {
- _, _, _ = checkItemExpiry(badItem, time.Now())
- }, "A type mismatch from a queue should cause a panic")
+ // Item should not be finalized by the processor
+ assert.Nil(t, item.FinalState(), "Item should not be finalized by the processor")
+ })
+ })
})
}
diff --git a/pkg/epp/flowcontrol/framework/mocks/mocks.go b/pkg/epp/flowcontrol/framework/mocks/mocks.go
index b8715b779..ff8441fde 100644
--- a/pkg/epp/flowcontrol/framework/mocks/mocks.go
+++ b/pkg/epp/flowcontrol/framework/mocks/mocks.go
@@ -67,14 +67,14 @@ var _ framework.FlowQueueAccessor = &MockFlowQueueAccessor{}
// Simple accessors are configured with public value fields (e.g., `PriorityV`).
// Complex methods with logic are configured with function fields (e.g., `IterateQueuesFunc`).
type MockPriorityBandAccessor struct {
- PriorityV uint
+ PriorityV int
PriorityNameV string
FlowKeysFunc func() []types.FlowKey
QueueFunc func(flowID string) framework.FlowQueueAccessor
IterateQueuesFunc func(callback func(queue framework.FlowQueueAccessor) (keepIterating bool))
}
-func (m *MockPriorityBandAccessor) Priority() uint { return m.PriorityV }
+func (m *MockPriorityBandAccessor) Priority() int { return m.PriorityV }
func (m *MockPriorityBandAccessor) PriorityName() string { return m.PriorityNameV }
func (m *MockPriorityBandAccessor) FlowKeys() []types.FlowKey {
diff --git a/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch/fcfs/fcfs.go b/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch/fcfs/fcfs.go
index edcc02ac6..7addb9d13 100644
--- a/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch/fcfs/fcfs.go
+++ b/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch/fcfs/fcfs.go
@@ -26,6 +26,32 @@ import (
)
// FCFSPolicyName is the name of the FCFS policy implementation.
+//
+// This policy implements a First-Come, First-Served (FCFS) strategy by selecting the item with the earliest logical
+// enqueue time.
+//
+// # Behavior and Queue Pairing
+//
+// The behavioral guarantees of this policy are critically dependent on the capabilities of the `framework.SafeQueue` it
+// is paired with. The system distinguishes between:
+// - "Logical Enqueue Time": The timestamp when a request first arrives at the `controller.FlowController`.
+// - "Physical Enqueue Time": The timestamp when a request is added to a specific shard's queue, which happens later.
+//
+// This policy's behavior changes accordingly:
+// - Paired with a `CapabilityPriorityConfigurable` queue, it provides strict FCFS ordering based on logical enqueue
+// time, aligning with this policy's vended `framework.ItemComparator`.
+// This configuration ensures that requests are processed in the order they arrived at the controller, providing the
+// most intuitive behavior.
+// - Paired with a `CapabilityFIFO` queue, it provides approximate FCFS ordering based on physical arrival order at
+// the `framework.SafeQueue`.
+// This configuration offers higher performance at the cost of strict logical-time ordering, as the
+// `controller.FlowController`'s "bounce-and-retry" mechanic for Draining shards means a bounced request may be
+// processed after a request that logically arrived later.
+//
+// Given that true end-to-end ordering is non-deterministic in a distributed system, this policy defaults to pairing with
+// a `CapabilityFIFO` queue (like "ListQueue") to prioritize performance and high throughput. For users who require the
+// strictest possible logical-time ordering that this layer can provide, explicitly pairing this policy with a
+// `CapabilityPriorityConfigurable` queue is recommended.
const FCFSPolicyName = "FCFS"
func init() {
@@ -35,7 +61,9 @@ func init() {
})
}
-// fcfs (First-Come, First-Served) implements the `framework.IntraFlowDispatchPolicy` interface.
+// fcfs is the internal implementation of the FCFS policy.
+// See the documentation for the exported `FCFSPolicyName` constant for detailed user-facing information about its
+// behavior.
type fcfs struct {
comparator framework.ItemComparator
}
@@ -70,9 +98,10 @@ func (p *fcfs) Comparator() framework.ItemComparator {
return p.comparator
}
-// RequiredQueueCapabilities specifies that this policy needs a queue that supports FIFO operations.
+// RequiredQueueCapabilities returns an empty slice, indicating that this policy can operate with any queue.
+// See the `FCFSPolicyName` constant's documentation for details on the behavioral trade-offs.
func (p *fcfs) RequiredQueueCapabilities() []framework.QueueCapability {
- return []framework.QueueCapability{framework.CapabilityFIFO}
+ return []framework.QueueCapability{}
}
// --- enqueueTimeComparator ---
diff --git a/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch/fcfs/fcfs_test.go b/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch/fcfs/fcfs_test.go
index 45c144238..cc6bceecf 100644
--- a/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch/fcfs/fcfs_test.go
+++ b/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch/fcfs/fcfs_test.go
@@ -41,8 +41,7 @@ func TestFCFS_RequiredQueueCapabilities(t *testing.T) {
t.Parallel()
policy := newFCFS()
caps := policy.RequiredQueueCapabilities()
- require.Len(t, caps, 1, "RequiredQueueCapabilities should return one capability")
- assert.Equal(t, framework.CapabilityFIFO, caps[0], "Required capability should be FIFO")
+ require.Empty(t, caps, "No required capabilities should be returned")
}
func TestFCFS_SelectItem(t *testing.T) {
diff --git a/pkg/epp/flowcontrol/framework/plugins/queue/listqueue/listqueue.go b/pkg/epp/flowcontrol/framework/plugins/queue/listqueue/listqueue.go
index 8e123b631..792e3a46d 100644
--- a/pkg/epp/flowcontrol/framework/plugins/queue/listqueue/listqueue.go
+++ b/pkg/epp/flowcontrol/framework/plugins/queue/listqueue/listqueue.go
@@ -14,8 +14,8 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
-// Package listqueue provides a simple, concurrent-safe queue implementation using a standard library
-// `container/list.List` as the underlying data structure for FIFO (First-In, First-Out) behavior.
+// Package listqueue provides a high-performance, concurrent-safe FIFO (First-In, First-Out) implementation of the
+// `framework.SafeQueue` interface, based on the standard library's `container/list`.
package listqueue
import (
@@ -28,7 +28,28 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types"
)
-// ListQueueName is the name of the list queue implementation.
+// ListQueueName is the name of the list-based queue implementation.
+//
+// This queue provides a high-performance, low-overhead implementation based on a standard `container/list`.
+// It advertises the `CapabilityFIFO` capability.
+//
+// # Behavioral Guarantees
+//
+// The core guarantee of this queue is strict physical First-In, First-Out (FIFO) ordering. It processes items in the
+// exact order they are added to the queue on a specific shard.
+//
+// # Performance and Trade-offs
+//
+// Because the physical insertion order may not match a request's logical arrival time (due to the
+// `controller.FlowController`'s internal "bounce-and-retry" mechanic), this queue provides an approximate FCFS behavior
+// from a system-wide perspective.
+//
+// Given that true end-to-end ordering is non-deterministic in a distributed system, this high-performance queue is the
+// recommended default for most FCFS-like policies. It prioritizes throughput and efficiency, which aligns with the
+// primary goal of the Flow Control system.
+//
+// For workloads that require the strictest possible logical-time ordering this layer can provide, explicitly using a
+// queue that supports `CapabilityPriorityConfigurable` is the appropriate choice.
const ListQueueName = "ListQueue"
func init() {
@@ -39,8 +60,8 @@ func init() {
})
}
-// listQueue implements the `framework.SafeQueue` interface using a standard `container/list.List` for FIFO behavior.
-// This implementation is concurrent-safe.
+// listQueue is the internal implementation of the ListQueue.
+// See the documentation for the exported `ListQueueName` constant for detailed user-facing information.
type listQueue struct {
requests *list.List
byteSize atomic.Uint64
diff --git a/pkg/epp/flowcontrol/framework/policies.go b/pkg/epp/flowcontrol/framework/policies.go
index 5fc3e646d..eeea034eb 100644
--- a/pkg/epp/flowcontrol/framework/policies.go
+++ b/pkg/epp/flowcontrol/framework/policies.go
@@ -168,7 +168,7 @@ type FlowQueueAccessor interface {
// Conformance: Implementations MUST ensure all methods are goroutine-safe for concurrent access.
type PriorityBandAccessor interface {
// Priority returns the numerical priority level of this band.
- Priority() uint
+ Priority() int
// PriorityName returns the human-readable name of this priority band.
PriorityName() string
diff --git a/pkg/epp/flowcontrol/registry/config.go b/pkg/epp/flowcontrol/registry/config.go
index d404e78f7..af345a665 100644
--- a/pkg/epp/flowcontrol/registry/config.go
+++ b/pkg/epp/flowcontrol/registry/config.go
@@ -98,7 +98,7 @@ type Config struct {
// the correct element within this specific configuration instance, preventing common "pointer-to-loop-variable"
// errors, especially across deep copies or partitioning.
// It is populated during validation and when the config is copied or partitioned.
- priorityBandMap map[uint]*PriorityBandConfig
+ priorityBandMap map[int]*PriorityBandConfig
// Factory functions used for plugin instantiation during configuration validation.
// These enable dependency injection for unit testing the validation logic.
@@ -112,9 +112,9 @@ type Config struct {
// that operate at this priority level.
type PriorityBandConfig struct {
// Priority is the unique numerical priority level for this band.
- // Convention: Lower numerical values indicate higher priority (e.g., 0 is highest).
+ // Convention: Higher numeric values indicate higher priority (values are centered on 0).
// Required.
- Priority uint
+ Priority int
// PriorityName is a human-readable name for this priority band (e.g., "Critical", "Standard").
// It must be unique across all priority bands in the configuration.
@@ -140,20 +140,6 @@ type PriorityBandConfig struct {
MaxBytes uint64
}
-// NewConfig performs validation and initialization, returning a guaranteed-valid `Config` object.
-// This is the required constructor for creating a new configuration. It applies provided functional options (primarily
-// for testing) and does not mutate the input `cfg`.
-func NewConfig(cfg Config, opts ...configOption) (*Config, error) {
- newCfg := cfg.deepCopy()
- for _, opt := range opts {
- opt(newCfg)
- }
- if err := newCfg.validateAndApplyDefaults(); err != nil {
- return nil, err
- }
- return newCfg, nil
-}
-
// =============================================================================
// Shard-Level Configuration
// =============================================================================
@@ -170,13 +156,13 @@ type ShardConfig struct {
// priorityBandMap provides O(1) lookups of `ShardPriorityBandConfig` by priority level.
// It serves as a correctness mechanism, ensuring that accessors return a safe, stable pointer to the correct element
// within this specific shard configuration instance.
- priorityBandMap map[uint]*ShardPriorityBandConfig
+ priorityBandMap map[int]*ShardPriorityBandConfig
}
// ShardPriorityBandConfig holds the partitioned configuration for a single priority band within a single shard.
type ShardPriorityBandConfig struct {
// Priority is the unique numerical priority level for this band.
- Priority uint
+ Priority int
// PriorityName is a unique human-readable name for this priority band.
PriorityName string
// IntraFlowDispatchPolicy is the name of the policy for dispatch within a flow's queue.
@@ -192,7 +178,7 @@ type ShardPriorityBandConfig struct {
// getBandConfig finds and returns the shard-level configuration for a specific priority level.
// Returns an error wrapping `contracts.ErrPriorityBandNotFound` if the priority is not configured.
-func (sc *ShardConfig) getBandConfig(priority uint) (*ShardPriorityBandConfig, error) {
+func (sc *ShardConfig) getBandConfig(priority int) (*ShardPriorityBandConfig, error) {
if band, ok := sc.priorityBandMap[priority]; ok {
return band, nil
}
@@ -205,52 +191,55 @@ func (sc *ShardConfig) getBandConfig(priority uint) (*ShardPriorityBandConfig, e
// --- Validation and Defaulting ---
-// validateAndApplyDefaults checks the global configuration for validity (including plugin compatibility) and mutates
-// the receiver to populate any empty fields with system defaults. It also initializes internal lookup maps.
-func (c *Config) validateAndApplyDefaults() error {
+// ValidateAndApplyDefaults checks the global configuration for validity and then creates a new `Config` object,
+// populating any empty fields with system defaults.
+// It does not mutate the receiver.
+func (c *Config) ValidateAndApplyDefaults() (*Config, error) {
+ cfg := c.deepCopy()
+
// Apply defaults to top-level fields.
- if c.InitialShardCount <= 0 {
- c.InitialShardCount = defaultInitialShardCount
+ if cfg.InitialShardCount <= 0 {
+ cfg.InitialShardCount = defaultInitialShardCount
}
- if c.FlowGCTimeout <= 0 {
- c.FlowGCTimeout = defaultFlowGCTimeout
+ if cfg.FlowGCTimeout <= 0 {
+ cfg.FlowGCTimeout = defaultFlowGCTimeout
}
- if c.EventChannelBufferSize <= 0 {
- c.EventChannelBufferSize = defaultEventChannelBufferSize
+ if cfg.EventChannelBufferSize <= 0 {
+ cfg.EventChannelBufferSize = defaultEventChannelBufferSize
}
// Ensure the DI factories are initialized for production use if `NewConfig` was called without options.
- if c.interFlowDispatchPolicyFactory == nil {
- c.interFlowDispatchPolicyFactory = inter.NewPolicyFromName
+ if cfg.interFlowDispatchPolicyFactory == nil {
+ cfg.interFlowDispatchPolicyFactory = inter.NewPolicyFromName
}
- if c.intraFlowDispatchPolicyFactory == nil {
- c.intraFlowDispatchPolicyFactory = intra.NewPolicyFromName
+ if cfg.intraFlowDispatchPolicyFactory == nil {
+ cfg.intraFlowDispatchPolicyFactory = intra.NewPolicyFromName
}
- if c.queueFactory == nil {
- c.queueFactory = queue.NewQueueFromName
+ if cfg.queueFactory == nil {
+ cfg.queueFactory = queue.NewQueueFromName
}
- if len(c.PriorityBands) == 0 {
- return errors.New("config validation failed: at least one priority band must be defined")
+ if len(cfg.PriorityBands) == 0 {
+ return nil, errors.New("config validation failed: at least one priority band must be defined")
}
// Validate and default each priority band.
- priorities := make(map[uint]struct{})
+ priorities := make(map[int]struct{})
priorityNames := make(map[string]struct{})
- c.priorityBandMap = make(map[uint]*PriorityBandConfig, len(c.PriorityBands))
+ cfg.priorityBandMap = make(map[int]*PriorityBandConfig, len(cfg.PriorityBands))
- for i := range c.PriorityBands {
- band := &c.PriorityBands[i]
+ for i := range cfg.PriorityBands {
+ band := &cfg.PriorityBands[i]
if _, exists := priorities[band.Priority]; exists {
- return fmt.Errorf("config validation failed: duplicate priority level %d found", band.Priority)
+ return nil, fmt.Errorf("config validation failed: duplicate priority level %d found", band.Priority)
}
priorities[band.Priority] = struct{}{}
if band.PriorityName == "" {
- return fmt.Errorf("config validation failed: PriorityName is required for priority band %d", band.Priority)
+ return nil, fmt.Errorf("config validation failed: PriorityName is required for priority band %d", band.Priority)
}
if _, exists := priorityNames[band.PriorityName]; exists {
- return fmt.Errorf("config validation failed: duplicate priority name %q found", band.PriorityName)
+ return nil, fmt.Errorf("config validation failed: duplicate priority name %q found", band.PriorityName)
}
priorityNames[band.PriorityName] = struct{}{}
@@ -267,12 +256,12 @@ func (c *Config) validateAndApplyDefaults() error {
band.MaxBytes = defaultPriorityBandMaxBytes
}
- if err := c.validateBandCompatibility(*band); err != nil {
- return err
+ if err := cfg.validateBandCompatibility(*band); err != nil {
+ return nil, err
}
- c.priorityBandMap[band.Priority] = band
+ cfg.priorityBandMap[band.Priority] = band
}
- return nil
+ return cfg, nil
}
// validateBandCompatibility verifies that a band's configured queue type has the necessary capabilities.
@@ -326,7 +315,7 @@ func (c *Config) partition(shardIndex, totalShards int) *ShardConfig {
shardCfg := &ShardConfig{
MaxBytes: partitionUint64(c.MaxBytes, shardIndex, totalShards),
PriorityBands: make([]ShardPriorityBandConfig, len(c.PriorityBands)),
- priorityBandMap: make(map[uint]*ShardPriorityBandConfig, len(c.PriorityBands)),
+ priorityBandMap: make(map[int]*ShardPriorityBandConfig, len(c.PriorityBands)),
}
for i, template := range c.PriorityBands {
@@ -423,6 +412,18 @@ func withQueueFactory(factory queueFactory) configOption {
}
}
+// newConfig creates a new validated and defaulted `Config` object.
+// It applies provided test-only functional options before validation and defaulting.
+// It does not mutate the input `cfg`.
+// test-only
+func newConfig(cfg Config, opts ...configOption) (*Config, error) {
+ newCfg := cfg.deepCopy()
+ for _, opt := range opts {
+ opt(newCfg)
+ }
+ return newCfg.ValidateAndApplyDefaults()
+}
+
// --- Internal Utilities ---
// deepCopy creates a deep copy of the `Config` object.
@@ -436,7 +437,6 @@ func (c *Config) deepCopy() *Config {
FlowGCTimeout: c.FlowGCTimeout,
EventChannelBufferSize: c.EventChannelBufferSize,
PriorityBands: make([]PriorityBandConfig, len(c.PriorityBands)),
- priorityBandMap: make(map[uint]*PriorityBandConfig, len(c.PriorityBands)),
interFlowDispatchPolicyFactory: c.interFlowDispatchPolicyFactory,
intraFlowDispatchPolicyFactory: c.intraFlowDispatchPolicyFactory,
queueFactory: c.queueFactory,
@@ -445,18 +445,21 @@ func (c *Config) deepCopy() *Config {
// PriorityBandConfig contains only value types, so a slice copy is sufficient for a deep copy.
copy(newCfg.PriorityBands, c.PriorityBands)
- // Crucial: We must rebuild the map and take the address of the elements within the new slice (`newCfg.PriorityBands`)
- // to ensure the map pointers are correct for the newly created `Config` instance.
- for i := range newCfg.PriorityBands {
- band := &newCfg.PriorityBands[i]
- newCfg.priorityBandMap[band.Priority] = band
+ if c.priorityBandMap != nil {
+ newCfg.priorityBandMap = make(map[int]*PriorityBandConfig, len(c.PriorityBands))
+ // Crucial: We must rebuild the map and take the address of the elements within the new slice (`newCfg.PriorityBands`)
+ // to ensure the map pointers are correct for the newly created `Config` instance.
+ for i := range newCfg.PriorityBands {
+ band := &newCfg.PriorityBands[i]
+ newCfg.priorityBandMap[band.Priority] = band
+ }
}
return newCfg
}
// getBandConfig finds and returns the global configuration template for a specific priority level.
// Returns an error wrapping `contracts.ErrPriorityBandNotFound` if the priority is not configured.
-func (c *Config) getBandConfig(priority uint) (*PriorityBandConfig, error) {
+func (c *Config) getBandConfig(priority int) (*PriorityBandConfig, error) {
if band, ok := c.priorityBandMap[priority]; ok {
return band, nil
}
diff --git a/pkg/epp/flowcontrol/registry/config_test.go b/pkg/epp/flowcontrol/registry/config_test.go
index 376282ffe..47814ae6e 100644
--- a/pkg/epp/flowcontrol/registry/config_test.go
+++ b/pkg/epp/flowcontrol/registry/config_test.go
@@ -35,7 +35,7 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/queue/listqueue"
)
-func TestConfig_NewConfig(t *testing.T) {
+func TestConfig_ValidateAndApplyDefaults(t *testing.T) {
t.Parallel()
testCases := []struct {
@@ -211,16 +211,26 @@ func TestConfig_NewConfig(t *testing.T) {
name: "ShouldError_WhenQueueFactoryFails",
input: Config{
PriorityBands: []PriorityBandConfig{{
- Priority: 1,
- PriorityName: "High",
- Queue: queue.RegisteredQueueName("failing-queue"),
+ Priority: 1,
+ PriorityName: "High",
+ Queue: queue.RegisteredQueueName("failing-queue"),
+ IntraFlowDispatchPolicy: intra.RegisteredPolicyName("policy-with-req"),
}},
},
expectErr: true,
- opts: []configOption{withQueueFactory(
- func(_ queue.RegisteredQueueName, _ framework.ItemComparator) (framework.SafeQueue, error) {
+ opts: []configOption{
+ withIntraFlowDispatchPolicyFactory( // Forces queue instance creation for validating capabilities.
+ func(name intra.RegisteredPolicyName) (framework.IntraFlowDispatchPolicy, error) {
+ return &mocks.MockIntraFlowDispatchPolicy{
+ NameV: string(name),
+ RequiredQueueCapabilitiesV: []framework.QueueCapability{"required-capability"},
+ }, nil
+ },
+ ),
+ withQueueFactory(func(_ queue.RegisteredQueueName, _ framework.ItemComparator) (framework.SafeQueue, error) {
return nil, errors.New("queue creation failed")
- })},
+ }),
+ },
},
{
name: "ShouldError_WhenPolicyFactoryFails",
@@ -276,20 +286,24 @@ func TestConfig_NewConfig(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
- originalInputCopy := tc.input.deepCopy()
- newCfg, err := NewConfig(tc.input, tc.opts...)
+ originalInput := tc.input.deepCopy()
+ validatedCfg, err := newConfig(tc.input, tc.opts...)
+
if tc.expectErr {
- require.Error(t, err, "NewConfig should have returned an error")
+ require.Error(t, err, "expected an error but got nil")
if tc.expectedErrIs != nil {
- assert.ErrorIs(t, err, tc.expectedErrIs, "Error should wrap the expected error type")
+ assert.ErrorIs(t, err, tc.expectedErrIs, "error should wrap the expected error type")
}
- assert.Nil(t, newCfg, "On error, the returned config should be nil")
+ assert.Nil(t, validatedCfg, "validatedCfg should be nil on error")
} else {
- require.NoError(t, err, "NewConfig should not have returned an error")
- require.NotNil(t, newCfg, "On success, the returned config should not be nil")
+ require.NoError(t, err, "expected no error but got: %v", err)
+ require.NotNil(t, validatedCfg, "validatedCfg should not be nil on success")
if tc.assertion != nil {
- tc.assertion(t, *originalInputCopy, newCfg)
+ tc.assertion(t, *originalInput, validatedCfg)
}
+
+ // Ensure the original config is not mutated.
+ assert.Equal(t, *originalInput, tc.input, "input config should not be mutated")
}
})
}
@@ -297,7 +311,7 @@ func TestConfig_NewConfig(t *testing.T) {
func TestConfig_Partition(t *testing.T) {
t.Parallel()
- baseCfg, err := NewConfig(Config{
+ baseCfg, err := newConfig(Config{
MaxBytes: 103, // Will not distribute evenly
PriorityBands: []PriorityBandConfig{
{Priority: 1, PriorityName: "High", MaxBytes: 55}, // Will not distribute evenly
@@ -381,7 +395,7 @@ func TestConfig_Partition(t *testing.T) {
func TestConfig_GetBandConfig(t *testing.T) {
t.Parallel()
- cfg, err := NewConfig(Config{
+ cfg, err := newConfig(Config{
PriorityBands: []PriorityBandConfig{
{Priority: 10, PriorityName: "High"},
},
@@ -417,7 +431,7 @@ func TestConfig_DeepCopy(t *testing.T) {
},
}
// Create a fully initialized "original" config to be the source of the copy.
- original, err := NewConfig(baseCfg)
+ original, err := newConfig(baseCfg)
require.NoError(t, err, "Setup for deep copy should not fail")
t.Run("ShouldReturnNil_ForNilReceiver", func(t *testing.T) {
@@ -471,7 +485,7 @@ func TestConfig_DeepCopy(t *testing.T) {
func TestShardConfig_GetBandConfig(t *testing.T) {
t.Parallel()
- baseCfg, err := NewConfig(Config{
+ baseCfg, err := newConfig(Config{
PriorityBands: []PriorityBandConfig{
{Priority: 10, PriorityName: "High"},
{Priority: 20, PriorityName: "Low"},
diff --git a/pkg/epp/flowcontrol/registry/connection.go b/pkg/epp/flowcontrol/registry/connection.go
index 995f23c13..cb9831655 100644
--- a/pkg/epp/flowcontrol/registry/connection.go
+++ b/pkg/epp/flowcontrol/registry/connection.go
@@ -31,13 +31,13 @@ type connection struct {
var _ contracts.ActiveFlowConnection = &connection{}
// Shards returns a stable snapshot of accessors for all internal state shards.
-func (c *connection) Shards() []contracts.RegistryShard {
+func (c *connection) ActiveShards() []contracts.RegistryShard {
c.registry.mu.RLock()
defer c.registry.mu.RUnlock()
// Return a copy to ensure the caller cannot modify the registry's internal slice.
- shardsCopy := make([]contracts.RegistryShard, len(c.registry.allShards))
- for i, s := range c.registry.allShards {
+ shardsCopy := make([]contracts.RegistryShard, len(c.registry.activeShards))
+ for i, s := range c.registry.activeShards {
shardsCopy[i] = s
}
return shardsCopy
diff --git a/pkg/epp/flowcontrol/registry/managedqueue_test.go b/pkg/epp/flowcontrol/registry/managedqueue_test.go
index f5e3d7fa7..64e5ab80c 100644
--- a/pkg/epp/flowcontrol/registry/managedqueue_test.go
+++ b/pkg/epp/flowcontrol/registry/managedqueue_test.go
@@ -98,7 +98,7 @@ type mockStatsPropagator struct {
byteSizeDelta atomic.Int64
}
-func (p *mockStatsPropagator) propagate(_ uint, lenDelta, byteSizeDelta int64) {
+func (p *mockStatsPropagator) propagate(_ int, lenDelta, byteSizeDelta int64) {
p.lenDelta.Add(lenDelta)
p.byteSizeDelta.Add(byteSizeDelta)
}
diff --git a/pkg/epp/flowcontrol/registry/registry.go b/pkg/epp/flowcontrol/registry/registry.go
index 95d604ede..3a73ef706 100644
--- a/pkg/epp/flowcontrol/registry/registry.go
+++ b/pkg/epp/flowcontrol/registry/registry.go
@@ -37,7 +37,7 @@ import (
// propagateStatsDeltaFunc defines the callback function used to propagate statistics changes (deltas) up the hierarchy
// (Queue -> Shard -> Registry).
// Implementations MUST be non-blocking (relying on atomics).
-type propagateStatsDeltaFunc func(priority uint, lenDelta, byteSizeDelta int64)
+type propagateStatsDeltaFunc func(priority int, lenDelta, byteSizeDelta int64)
// bandStats holds the aggregated atomic statistics for a single priority band across all shards.
type bandStats struct {
@@ -120,7 +120,7 @@ type FlowRegistry struct {
// Globally aggregated statistics, updated atomically via lock-free propagation.
totalByteSize atomic.Int64
totalLen atomic.Int64
- perPriorityBandStats map[uint]*bandStats // Keyed by priority.
+ perPriorityBandStats map[int]*bandStats // Keyed by priority.
// --- Administrative state (protected by `mu`) ---
@@ -148,17 +148,13 @@ func withClock(clk clock.WithTickerAndDelayedExecution) RegistryOption {
// NewFlowRegistry creates and initializes a new `FlowRegistry` instance.
func NewFlowRegistry(config Config, logger logr.Logger, opts ...RegistryOption) (*FlowRegistry, error) {
- validatedConfig, err := NewConfig(config)
- if err != nil {
- return nil, fmt.Errorf("master configuration is invalid: %w", err)
- }
-
+ cfg := config.deepCopy()
fr := &FlowRegistry{
- config: validatedConfig,
+ config: cfg,
logger: logger.WithName("flow-registry"),
activeShards: []*registryShard{},
drainingShards: make(map[string]*registryShard),
- perPriorityBandStats: make(map[uint]*bandStats, len(validatedConfig.PriorityBands)),
+ perPriorityBandStats: make(map[int]*bandStats, len(cfg.PriorityBands)),
}
for _, opt := range opts {
@@ -173,7 +169,7 @@ func NewFlowRegistry(config Config, logger logr.Logger, opts ...RegistryOption)
fr.perPriorityBandStats[band.Priority] = &bandStats{}
}
- if err := fr.updateShardCount(validatedConfig.InitialShardCount); err != nil {
+ if err := fr.updateShardCount(cfg.InitialShardCount); err != nil {
return nil, fmt.Errorf("failed to initialize shards: %w", err)
}
fr.logger.V(logging.DEFAULT).Info("FlowRegistry initialized successfully")
@@ -198,7 +194,7 @@ func (fr *FlowRegistry) Run(ctx context.Context) {
}
}
-// --- `contracts.FlowRegistryClient` Implementation ---
+// --- `contracts.FlowRegistryDataPlane` Implementation ---
// Connect establishes a session for a given flow, acquiring a lifecycle lease.
// This is the primary entry point for the data path.
@@ -275,7 +271,7 @@ func (fr *FlowRegistry) prepareNewFlow(key types.FlowKey) (*flowState, error) {
return &flowState{key: key}, nil
}
-// --- `contracts.FlowRegistryAdmin` Implementation ---
+// --- `contracts.FlowRegistryObserver` Implementation ---
// Stats returns globally aggregated statistics for the entire `FlowRegistry`.
//
@@ -289,7 +285,7 @@ func (fr *FlowRegistry) Stats() contracts.AggregateStats {
TotalCapacityBytes: fr.config.MaxBytes,
TotalByteSize: uint64(fr.totalByteSize.Load()),
TotalLen: uint64(fr.totalLen.Load()),
- PerPriorityBandStats: make(map[uint]contracts.PriorityBandStats, len(fr.config.PriorityBands)),
+ PerPriorityBandStats: make(map[int]contracts.PriorityBandStats, len(fr.config.PriorityBands)),
}
for p, s := range fr.perPriorityBandStats {
@@ -592,7 +588,7 @@ func (fr *FlowRegistry) updateAllShardsCacheLocked() {
}
// propagateStatsDelta is the top-level, lock-free aggregator for all statistics.
-func (fr *FlowRegistry) propagateStatsDelta(priority uint, lenDelta, byteSizeDelta int64) {
+func (fr *FlowRegistry) propagateStatsDelta(priority int, lenDelta, byteSizeDelta int64) {
stats, ok := fr.perPriorityBandStats[priority]
if !ok {
panic(fmt.Sprintf("invariant violation: priority band (%d) stats missing during propagation", priority))
diff --git a/pkg/epp/flowcontrol/registry/registry_test.go b/pkg/epp/flowcontrol/registry/registry_test.go
index 5ffa600d0..b5bc322cb 100644
--- a/pkg/epp/flowcontrol/registry/registry_test.go
+++ b/pkg/epp/flowcontrol/registry/registry_test.go
@@ -73,9 +73,12 @@ func newRegistryTestHarness(t *testing.T, opts harnessOptions) *registryTestHarn
config.InitialShardCount = opts.initialShardCount
}
+ validatedCfg, err := config.ValidateAndApplyDefaults()
+ require.NoError(t, err, "Test setup: validating config should not fail")
+
fakeClock := testclock.NewFakeClock(time.Now())
registryOpts := []RegistryOption{withClock(fakeClock)}
- fr, err := NewFlowRegistry(config, logr.Discard(), registryOpts...)
+ fr, err := NewFlowRegistry(*validatedCfg, logr.Discard(), registryOpts...)
require.NoError(t, err, "Test setup: NewFlowRegistry should not fail")
// Start the GC loop in the background.
@@ -132,69 +135,9 @@ func (h *registryTestHarness) openConnectionOnFlow(key types.FlowKey) {
func TestFlowRegistry_New(t *testing.T) {
t.Parallel()
- t.Run("ShouldApplyDefaults_WhenInitialized", func(t *testing.T) {
- t.Parallel()
- config := Config{PriorityBands: []PriorityBandConfig{{Priority: highPriority, PriorityName: "DefaultedBand"}}}
- fr, err := NewFlowRegistry(config, logr.Discard())
- require.NoError(t, err, "Creating a valid registry with defaults should not fail")
- assert.Equal(t, defaultInitialShardCount, fr.config.InitialShardCount, "InitialShardCount should be defaulted")
- assert.Equal(t, defaultFlowGCTimeout, fr.config.FlowGCTimeout, "FlowGCTimeout should be defaulted")
- assert.Equal(t, defaultEventChannelBufferSize, fr.config.EventChannelBufferSize,
- "EventChannelBufferSize should be defaulted")
- assert.Len(t, fr.allShards, defaultInitialShardCount,
- "Registry should be initialized with the default number of shards")
- bandConf, err := fr.config.getBandConfig(highPriority)
- require.NoError(t, err, "Getting the defaulted band config should not fail")
- assert.Equal(t, defaultPriorityBandMaxBytes, bandConf.MaxBytes, "Priority band MaxBytes should be defaulted")
- })
-
- t.Run("ShouldFail_OnInvalidConfiguration", func(t *testing.T) {
- t.Parallel()
- testCases := []struct {
- name string
- config Config
- expectErrSubStr string
- }{
- {
- name: "WhenNoPriorityBandsAreDefined",
- config: Config{},
- expectErrSubStr: "at least one priority band must be defined",
- },
- {
- name: "WhenPriorityLevelsAreDuplicated",
- config: Config{
- PriorityBands: []PriorityBandConfig{
- {Priority: highPriority, PriorityName: "A"},
- {Priority: highPriority, PriorityName: "B"},
- },
- },
- expectErrSubStr: fmt.Sprintf("duplicate priority level %d", highPriority),
- },
- {
- name: "WhenPriorityNamesAreDuplicated",
- config: Config{
- PriorityBands: []PriorityBandConfig{
- {Priority: highPriority, PriorityName: "A"},
- {Priority: lowPriority, PriorityName: "A"},
- },
- },
- expectErrSubStr: `duplicate priority name "A"`,
- },
- }
-
- for _, tc := range testCases {
- t.Run(tc.name, func(t *testing.T) {
- t.Parallel()
- _, err := NewFlowRegistry(tc.config, logr.Discard())
- require.Error(t, err, "NewFlowRegistry should fail with an invalid config")
- assert.Contains(t, err.Error(), tc.expectErrSubStr, "Error message should contain the expected reason")
- })
- }
- })
-
t.Run("ShouldFail_WhenInitialShardCreationFails", func(t *testing.T) {
t.Parallel()
- config, err := NewConfig(
+ config, err := newConfig(
Config{PriorityBands: []PriorityBandConfig{{Priority: highPriority, PriorityName: "A"}}},
withInterFlowDispatchPolicyFactory(func(inter.RegisteredPolicyName) (framework.InterFlowDispatchPolicy, error) {
return nil, errors.New("injected factory failure")
@@ -261,7 +204,7 @@ func TestFlowRegistry_WithConnection_AndHandle(t *testing.T) {
assert.ErrorContains(t, err, "injected factory failure", "The returned error must propagate the reason")
})
- t.Run("Handle_Shards_ShouldReturnAllShardsAndBeACopy", func(t *testing.T) {
+ t.Run("Handle_Shards_ShouldReturnAllActiveShardsAndBeACopy", func(t *testing.T) {
t.Parallel()
// Create a registry with a known mixed topology of Active and Draining shards.
h := newRegistryTestHarness(t, harnessOptions{initialShardCount: 3})
@@ -272,9 +215,9 @@ func TestFlowRegistry_WithConnection_AndHandle(t *testing.T) {
key := types.FlowKey{ID: "test-flow", Priority: highPriority}
err = h.fr.WithConnection(key, func(conn contracts.ActiveFlowConnection) error {
- shards := conn.Shards()
+ shards := conn.ActiveShards()
- assert.Len(t, shards, 3, "Shards() must return all configured shards, including Draining ones")
+ assert.Len(t, shards, 2, "ActiveShards() must only return the Active shards")
// Assert it's a copy by maliciously modifying it.
require.NotEmpty(t, shards, "Test setup assumes shards are present")
@@ -285,8 +228,8 @@ func TestFlowRegistry_WithConnection_AndHandle(t *testing.T) {
require.NoError(t, err)
// Prove the registry's internal state was not mutated by the modification.
- assert.NotNil(t, h.fr.allShards[0],
- "Modifying the slice returned by Shards() must not affect the registry's internal state")
+ assert.NotNil(t, h.fr.activeShards[0],
+ "Modifying the slice returned by ActiveShards() must not affect the registry's internal state")
})
}
@@ -323,6 +266,9 @@ func TestFlowRegistry_Stats(t *testing.T) {
require.Len(t, shardStats, 2, "Should return stats for 2 shards")
var totalShardLen, totalShardBytes uint64
for _, ss := range shardStats {
+ assert.True(t, ss.IsActive, "All shards should be active in this test")
+ assert.NotEmpty(t, ss.PerPriorityBandStats, "Each shard should have stats for its priority bands")
+ assert.NotEmpty(t, ss.ID, "Each shard should have a non-empty ID")
totalShardLen += ss.TotalLen
totalShardBytes += ss.TotalByteSize
}
@@ -541,14 +487,6 @@ func TestFlowRegistry_UpdateShardCount(t *testing.T) {
expectedPartitionedGlobalCapacities: map[uint64]int{25: 4},
expectedPartitionedBandCapacities: map[uint64]int{12: 2, 13: 2},
},
- {
- name: "Succeeds_ScaleUp_FromZero",
- initialShardCount: 0,
- targetShardCount: 4,
- expectedActiveCount: 4,
- expectedPartitionedGlobalCapacities: map[uint64]int{25: 4},
- expectedPartitionedBandCapacities: map[uint64]int{12: 2, 13: 2},
- },
{
name: "Succeeds_ScaleDown_ToOne",
initialShardCount: 3,
@@ -589,7 +527,7 @@ func TestFlowRegistry_UpdateShardCount(t *testing.T) {
}
h := newRegistryTestHarness(t, harnessOptions{config: &config})
- key := types.FlowKey{ID: "flow", Priority: 10}
+ key := types.FlowKey{ID: "flow", Priority: highPriority}
h.openConnectionOnFlow(key)
err := h.fr.updateShardCount(tc.targetShardCount)
@@ -600,24 +538,19 @@ func TestFlowRegistry_UpdateShardCount(t *testing.T) {
require.NoError(t, err, "UpdateShardCount should not have returned an error")
}
- var finalActiveCount, finalDrainingCount int
globalCapacities := make(map[uint64]int)
bandCapacities := make(map[uint64]int)
- err = h.fr.WithConnection(key, func(conn contracts.ActiveFlowConnection) error {
- for _, shard := range conn.Shards() {
- if shard.IsActive() {
- finalActiveCount++
- stats := shard.Stats()
- globalCapacities[stats.TotalCapacityBytes]++
- bandCapacities[stats.PerPriorityBandStats[highPriority].CapacityBytes]++
- h.assertFlowExists(key, "Shard %s should contain the existing flow", shard.ID())
- } else {
- finalDrainingCount++
- }
- }
- return nil
- })
- require.NoError(t, err, "WithConnection should not fail")
+
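+ // Inspect the registry's internal shard lists directly (under the read lock), since draining shards are no longer
+ // exposed through the connection's ActiveShards() view.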
+ h.fr.mu.RLock()
+ finalActiveCount := len(h.fr.activeShards)
+ finalDrainingCount := len(h.fr.drainingShards)
+ for _, shard := range h.fr.activeShards {
+ stats := shard.Stats()
+ globalCapacities[stats.TotalCapacityBytes]++
+ bandCapacities[stats.PerPriorityBandStats[highPriority].CapacityBytes]++
+ h.assertFlowExists(key, "Shard %s should contain the existing flow", shard.ID())
+ }
+ h.fr.mu.RUnlock()
expectedDrainingCount := 0
if tc.initialShardCount > tc.expectedActiveCount {
diff --git a/pkg/epp/flowcontrol/registry/shard.go b/pkg/epp/flowcontrol/registry/shard.go
index 36032e42c..4a7918e79 100644
--- a/pkg/epp/flowcontrol/registry/shard.go
+++ b/pkg/epp/flowcontrol/registry/shard.go
@@ -18,7 +18,7 @@ package registry
import (
"fmt"
- "slices"
+ "sort"
"sync"
"sync/atomic"
@@ -76,7 +76,7 @@ type registryShard struct {
// onStatsDelta is the callback used to propagate statistics changes up to the parent registry.
onStatsDelta propagateStatsDeltaFunc
// orderedPriorityLevels is a cached, sorted list of priority levels.
- orderedPriorityLevels []uint
+ orderedPriorityLevels []int
// --- State Protected by `mu` ---
@@ -88,7 +88,7 @@ type registryShard struct {
// config holds the partitioned configuration for this shard, derived from the `FlowRegistry`'s global `Config`.
config *ShardConfig
// priorityBands is the primary lookup table for all managed queues on this shard.
- priorityBands map[uint]*priorityBand
+ priorityBands map[int]*priorityBand
// --- Concurrent-Safe State (Atomics) ---
@@ -116,7 +116,7 @@ func newShard(
logger: shardLogger,
config: config,
onStatsDelta: onStatsDelta,
- priorityBands: make(map[uint]*priorityBand, len(config.PriorityBands)),
+ priorityBands: make(map[int]*priorityBand, len(config.PriorityBands)),
}
for _, bandConfig := range config.PriorityBands {
@@ -133,8 +133,9 @@ func newShard(
}
s.orderedPriorityLevels = append(s.orderedPriorityLevels, bandConfig.Priority)
}
-
- slices.Sort(s.orderedPriorityLevels)
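+ // Sort descending so that iteration over priority bands always visits the highest priority level first.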
+ sort.Slice(s.orderedPriorityLevels, func(i, j int) bool {
+ return s.orderedPriorityLevels[i] > s.orderedPriorityLevels[j]
+ })
s.logger.V(logging.DEFAULT).Info("Registry shard initialized successfully",
"priorityBandCount", len(s.priorityBands), "orderedPriorities", s.orderedPriorityLevels)
return s, nil
@@ -184,7 +185,7 @@ func (s *registryShard) IntraFlowDispatchPolicy(key types.FlowKey) (framework.In
// InterFlowDispatchPolicy retrieves a priority band's configured `framework.InterFlowDispatchPolicy`.
// This read is lock-free as the policy instance is immutable after the shard is initialized.
-func (s *registryShard) InterFlowDispatchPolicy(priority uint) (framework.InterFlowDispatchPolicy, error) {
+func (s *registryShard) InterFlowDispatchPolicy(priority int) (framework.InterFlowDispatchPolicy, error) {
// This read is safe because the `priorityBands` map structure is immutable after initialization.
band, ok := s.priorityBands[priority]
if !ok {
@@ -195,7 +196,7 @@ func (s *registryShard) InterFlowDispatchPolicy(priority uint) (framework.InterF
}
// PriorityBandAccessor retrieves a read-only view for a given priority level.
-func (s *registryShard) PriorityBandAccessor(priority uint) (framework.PriorityBandAccessor, error) {
+func (s *registryShard) PriorityBandAccessor(priority int) (framework.PriorityBandAccessor, error) {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -209,7 +210,7 @@ func (s *registryShard) PriorityBandAccessor(priority uint) (framework.PriorityB
// AllOrderedPriorityLevels returns a cached, sorted slice of all configured priority levels for this shard.
// This is a lock-free read.
-func (s *registryShard) AllOrderedPriorityLevels() []uint {
+func (s *registryShard) AllOrderedPriorityLevels() []int {
return s.orderedPriorityLevels
}
@@ -227,10 +228,12 @@ func (s *registryShard) Stats() contracts.ShardStats {
// Casts from `int64` to `uint64` are safe because the non-negative invariant is strictly enforced at the
// `managedQueue` level.
stats := contracts.ShardStats{
+ ID: s.id,
+ IsActive: s.IsActive(),
TotalCapacityBytes: s.config.MaxBytes,
TotalByteSize: uint64(s.totalByteSize.Load()),
TotalLen: uint64(s.totalLen.Load()),
- PerPriorityBandStats: make(map[uint]contracts.PriorityBandStats, len(s.priorityBands)),
+ PerPriorityBandStats: make(map[int]contracts.PriorityBandStats, len(s.priorityBands)),
}
for priority, band := range s.priorityBands {
@@ -325,7 +328,7 @@ func (s *registryShard) updateConfig(newConfig *ShardConfig) {
// propagateStatsDelta is the single point of entry for all statistics changes within the shard.
// It atomically updates the relevant band's stats, the shard's total stats, and propagates the delta to the parent
// registry.
-func (s *registryShard) propagateStatsDelta(priority uint, lenDelta, byteSizeDelta int64) {
+func (s *registryShard) propagateStatsDelta(priority int, lenDelta, byteSizeDelta int64) {
// This read is safe because the `priorityBands` map structure is immutable after initialization.
band, ok := s.priorityBands[priority]
if !ok {
@@ -355,7 +358,7 @@ type priorityBandAccessor struct {
var _ framework.PriorityBandAccessor = &priorityBandAccessor{}
// Priority returns the numerical priority level of this band.
-func (a *priorityBandAccessor) Priority() uint { return a.band.config.Priority }
+func (a *priorityBandAccessor) Priority() int { return a.band.config.Priority }
// PriorityName returns the human-readable name of this priority band.
func (a *priorityBandAccessor) PriorityName() string { return a.band.config.PriorityName }
diff --git a/pkg/epp/flowcontrol/registry/shard_test.go b/pkg/epp/flowcontrol/registry/shard_test.go
index 214497e41..23bf81325 100644
--- a/pkg/epp/flowcontrol/registry/shard_test.go
+++ b/pkg/epp/flowcontrol/registry/shard_test.go
@@ -37,11 +37,11 @@ import (
const (
// highPriority is the priority level for the "High" priority band in the test harness config.
- highPriority uint = 10
+ highPriority int = 20
// lowPriority is the priority level for the "Low" priority band in the test harness config.
- lowPriority uint = 20
+ lowPriority int = 10
// nonExistentPriority is a priority that is known not to exist in the test harness config.
- nonExistentPriority uint = 99
+ nonExistentPriority int = 99
)
// --- Test Harness and Mocks ---
@@ -59,7 +59,7 @@ type shardTestHarness struct {
// newShardTestHarness initializes a `shardTestHarness` with a default configuration.
func newShardTestHarness(t *testing.T) *shardTestHarness {
t.Helper()
- globalConfig, err := NewConfig(Config{
+ globalConfig, err := newConfig(Config{
PriorityBands: []PriorityBandConfig{
{Priority: highPriority, PriorityName: "High"},
{Priority: lowPriority, PriorityName: "Low"},
@@ -133,7 +133,7 @@ func TestShard_New(t *testing.T) {
assert.Equal(t, "test-shard-1", h.shard.ID(), "Shard ID must match the value provided during construction")
assert.True(t, h.shard.IsActive(), "A newly created shard must be initialized in the Active state")
- assert.Equal(t, []uint{highPriority, lowPriority}, h.shard.AllOrderedPriorityLevels(),
+ assert.Equal(t, []int{highPriority, lowPriority}, h.shard.AllOrderedPriorityLevels(),
"Shard must report configured priority levels sorted numerically (highest priority first)")
bandHigh, ok := h.shard.priorityBands[highPriority]
@@ -146,7 +146,7 @@ func TestShard_New(t *testing.T) {
t.Run("ShouldFail_WhenInterFlowPolicyFactoryFails", func(t *testing.T) {
t.Parallel()
- shardConfig, _ := NewConfig(Config{PriorityBands: []PriorityBandConfig{
+ shardConfig, _ := newConfig(Config{PriorityBands: []PriorityBandConfig{
{Priority: highPriority, PriorityName: "High"},
}})
failingFactory := func(inter.RegisteredPolicyName) (framework.InterFlowDispatchPolicy, error) {
@@ -165,6 +165,8 @@ func TestShard_Stats(t *testing.T) {
stats := h.shard.Stats()
+ assert.Equal(t, h.shard.ID(), stats.ID, "Stats ID must match the shard ID")
+ assert.True(t, stats.IsActive, "Shard must report itself as active in the stats snapshot")
assert.Equal(t, uint64(2), stats.TotalLen, "Total shard length must aggregate counts from all bands")
assert.Equal(t, uint64(150), stats.TotalByteSize, "Total shard byte size must aggregate sizes from all bands")
diff --git a/pkg/epp/flowcontrol/types/errors.go b/pkg/epp/flowcontrol/types/errors.go
index f7dffbd4d..8c966bb45 100644
--- a/pkg/epp/flowcontrol/types/errors.go
+++ b/pkg/epp/flowcontrol/types/errors.go
@@ -43,11 +43,8 @@ var (
// The following errors can occur before a request is formally added to a `framework.SafeQueue`. When returned by
// `FlowController.EnqueueAndWait()`, these specific errors will typically be wrapped by `ErrRejected`.
var (
- // ErrNilRequest indicates that a nil `types.FlowControlRequest` was provided.
- ErrNilRequest = errors.New("FlowControlRequest cannot be nil")
-
// ErrQueueAtCapacity indicates that a request could not be enqueued because queue capacity limits were met.
- ErrQueueAtCapacity = errors.New("queue at capacity and displacement failed to make space")
+ ErrQueueAtCapacity = errors.New("queue at capacity")
)
// --- Post-Enqueue Eviction Errors ---
@@ -68,10 +65,10 @@ var (
// --- General `controller.FlowController` Errors ---
var (
- // ErrFlowControllerShutdown indicates that an operation could not complete or an item was evicted because the
- // `controller.FlowController` is shutting down or has stopped.
+ // ErrFlowControllerNotRunning indicates that an operation could not complete or an item was evicted because the
+ // `controller.FlowController` is not running or is in the process of shutting down.
//
// When returned by `FlowController.EnqueueAndWait()`, this will be wrapped by `ErrRejected` (if rejection happens
// before internal queuing) or `ErrEvicted` (if eviction happens after internal queuing).
- ErrFlowControllerShutdown = errors.New("FlowController is shutting down")
+ ErrFlowControllerNotRunning = errors.New("flow controller is not running")
)
diff --git a/pkg/epp/flowcontrol/types/flow.go b/pkg/epp/flowcontrol/types/flow.go
index 71017c13c..2af2d5bd0 100644
--- a/pkg/epp/flowcontrol/types/flow.go
+++ b/pkg/epp/flowcontrol/types/flow.go
@@ -41,7 +41,7 @@ type FlowKey struct {
//
// Because the `FlowKey` is immutable, changing the priority of traffic requires using a new `FlowKey`; the old flow
// instance will be automatically garbage collected by the registry when it becomes idle.
- Priority uint
+ Priority int
}
func (k FlowKey) String() string {
@@ -49,13 +49,13 @@ func (k FlowKey) String() string {
}
// Compare provides a stable comparison function for two FlowKey instances, suitable for use with sorting algorithms.
-// It returns -1 if the key is less than the other, 0 if they are equal, and 1 if the key is greater than the other.
+// It returns -1 if the key sorts before the other, 0 if they are equal, and 1 if it sorts after.
-// The comparison is performed first by `Priority` (ascending, higher priority first) and then by `ID` (ascending).
+// The comparison is performed first by `Priority` (descending numerically, so higher priorities sort first) and then by `ID` (ascending).
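+// For example, {Priority: 20, ID: "a"} sorts before {Priority: 10, ID: "a"}; keys with equal priority are ordered by ID.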
func (k FlowKey) Compare(other FlowKey) int {
- if k.Priority < other.Priority { // Lower number means higher priority
+ if k.Priority > other.Priority { // Higher number means higher priority
return -1
}
- if k.Priority > other.Priority {
+ if k.Priority < other.Priority {
return 1
}
return strings.Compare(k.ID, other.ID)
diff --git a/pkg/epp/flowcontrol/types/mocks/mocks.go b/pkg/epp/flowcontrol/types/mocks/mocks.go
index dbef031d7..5fabf3683 100644
--- a/pkg/epp/flowcontrol/types/mocks/mocks.go
+++ b/pkg/epp/flowcontrol/types/mocks/mocks.go
@@ -19,19 +19,19 @@ limitations under the License.
package mocks
import (
- "context"
"time"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types"
)
// MockFlowControlRequest provides a mock implementation of the `types.FlowControlRequest` interface.
type MockFlowControlRequest struct {
- Ctx context.Context
- FlowKeyV types.FlowKey
- ByteSizeV uint64
- InitialEffectiveTTLV time.Duration
- IDV string
+ FlowKeyV types.FlowKey
+ ByteSizeV uint64
+ InitialEffectiveTTLV time.Duration
+ IDV string
+ CandidatePodsForSchedulingV []*metrics.FakePodMetrics
}
// NewMockFlowControlRequest creates a new `MockFlowControlRequest` instance.
@@ -39,25 +39,27 @@ func NewMockFlowControlRequest(
byteSize uint64,
id string,
key types.FlowKey,
- ctx context.Context,
) *MockFlowControlRequest {
- if ctx == nil {
- ctx = context.Background()
- }
return &MockFlowControlRequest{
ByteSizeV: byteSize,
IDV: id,
FlowKeyV: key,
- Ctx: ctx,
}
}
-func (m *MockFlowControlRequest) Context() context.Context { return m.Ctx }
func (m *MockFlowControlRequest) FlowKey() types.FlowKey { return m.FlowKeyV }
func (m *MockFlowControlRequest) ByteSize() uint64 { return m.ByteSizeV }
func (m *MockFlowControlRequest) InitialEffectiveTTL() time.Duration { return m.InitialEffectiveTTLV }
func (m *MockFlowControlRequest) ID() string { return m.IDV }
+func (m *MockFlowControlRequest) CandidatePodsForScheduling() []metrics.PodMetrics {
+ pods := make([]metrics.PodMetrics, len(m.CandidatePodsForSchedulingV))
+ for i, pod := range m.CandidatePodsForSchedulingV {
+ pods[i] = pod
+ }
+ return pods
+}
+
var _ types.FlowControlRequest = &MockFlowControlRequest{}
// MockQueueItemHandle provides a mock implementation of the `types.QueueItemHandle` interface.
@@ -104,7 +106,6 @@ func NewMockQueueItemAccessor(byteSize uint64, reqID string, key types.FlowKey)
byteSize,
reqID,
key,
- context.Background(),
),
HandleV: &MockQueueItemHandle{},
}
diff --git a/pkg/epp/flowcontrol/types/request.go b/pkg/epp/flowcontrol/types/request.go
index 756940704..e427b0aba 100644
--- a/pkg/epp/flowcontrol/types/request.go
+++ b/pkg/epp/flowcontrol/types/request.go
@@ -17,8 +17,9 @@ limitations under the License.
package types
import (
- "context"
"time"
+
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
)
// FlowControlRequest is the contract for an incoming request submitted to the `controller.FlowController`. It
@@ -28,11 +29,6 @@ import (
// wraps this object with its own internal structures (which implement `QueueItemAccessor`) to manage the request's
// lifecycle without modifying the original.
type FlowControlRequest interface {
- // Context returns the request's context. The `controller.FlowController` uses this for monitoring cancellation (e.g.,
- // if the client disconnects or a request-scoped timeout occurs), which can lead to the request being evicted from a
- // queue.
- Context() context.Context
-
// FlowKey returns the composite key that uniquely identifies the flow instance this request belongs to.
// The `controller.FlowController` uses this key as the primary identifier to look up the correct
// `contracts.ManagedQueue` and configured `framework.IntraFlowDispatchPolicy` from a `contracts.RegistryShard`.
@@ -49,6 +45,11 @@ type FlowControlRequest interface {
// applied.
InitialEffectiveTTL() time.Duration
+ // CandidatePodsForScheduling returns the set of candidate pods to which this request may be admitted.
+ // This is necessary for invoking `contracts.SaturationDetector.IsSaturated`, but it is otherwise unused in the Flow
+ // Control system.
+ CandidatePodsForScheduling() []metrics.PodMetrics
+
// ID returns an optional, user-facing unique identifier for this specific request. It is intended for logging,
// tracing, and observability. The `controller.FlowController` does not use this ID for dispatching decisions; it uses
// the internal, opaque `QueueItemHandle`.
@@ -92,7 +93,8 @@ type QueueItemAccessor interface {
OriginalRequest() FlowControlRequest
// EnqueueTime is the timestamp when the item was logically accepted by the `controller.FlowController` for queuing
- // (i.e., when `controller.FlowController.EnqueueAndWait()` was called).
+ // (i.e., when `controller.FlowController.EnqueueAndWait()` was called). It does not reflect the time the request
+ // landed in a `framework.SafeQueue` instance.
EnqueueTime() time.Time
// EffectiveTTL is the actual Time-To-Live assigned to this item by the `controller.FlowController`, taking into
diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go
index 7f8122195..bfc4e147a 100644
--- a/pkg/epp/handlers/request.go
+++ b/pkg/epp/handlers/request.go
@@ -17,7 +17,6 @@ limitations under the License.
package handlers
import (
- "fmt"
"strconv"
"time"
@@ -29,6 +28,13 @@ import (
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
)
+const (
+ // defaultFairnessID is the default fairness ID used when no ID is provided in the request.
+ // This ensures that requests without explicit fairness identifiers are still grouped and managed by the Flow Control
+ // system.
+ defaultFairnessID = "default-flow"
+)
+
func (s *StreamingServer) HandleRequestHeaders(reqCtx *RequestContext, req *extProcPb.ProcessingRequest_RequestHeaders) error {
reqCtx.RequestReceivedTimestamp = time.Now()
@@ -42,14 +48,7 @@ func (s *StreamingServer) HandleRequestHeaders(reqCtx *RequestContext, req *extP
if pod == nil {
return errutil.Error{Code: errutil.Internal, Msg: "no pods available in datastore"}
}
- pool, err := s.datastore.PoolGet()
- if err != nil {
- return err
- }
- if len(pool.Spec.TargetPorts) != 1 {
- return fmt.Errorf("expected 1 target port, got %d", len(pool.Spec.TargetPorts))
- }
- reqCtx.TargetEndpoint = pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPorts[0].Number))
+ reqCtx.TargetEndpoint = pod.GetIPAddress() + ":" + pod.GetPort()
reqCtx.RequestSize = 0
reqCtx.reqHeaderResp = s.generateRequestHeaderResponse(reqCtx)
return nil
@@ -80,6 +79,11 @@ func (s *StreamingServer) HandleRequestHeaders(reqCtx *RequestContext, req *extP
delete(reqCtx.Request.Headers, header.Key)
}
}
+
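+ // Fall back to the default fairness ID so the Flow Control layer can still group and track this request.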
+ if reqCtx.FairnessID == "" {
+ reqCtx.FairnessID = defaultFairnessID
+ }
+
return nil
}
diff --git a/pkg/epp/handlers/request_test.go b/pkg/epp/handlers/request_test.go
index 4ae207803..a3ef90cb4 100644
--- a/pkg/epp/handlers/request_test.go
+++ b/pkg/epp/handlers/request_test.go
@@ -21,6 +21,7 @@ import (
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
+ "github.com/stretchr/testify/assert"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
)
@@ -66,3 +67,32 @@ func TestHandleRequestHeaders(t *testing.T) {
t.Errorf("expected fairness ID header to be removed from request headers, but it was not")
}
}
+
+func TestHandleRequestHeaders_DefaultFairnessID(t *testing.T) {
+ t.Parallel()
+
+ server := &StreamingServer{}
+ reqCtx := &RequestContext{
+ Request: &Request{
+ Headers: make(map[string]string),
+ },
+ }
+
+ req := &extProcPb.ProcessingRequest_RequestHeaders{
+ RequestHeaders: &extProcPb.HttpHeaders{
+ Headers: &configPb.HeaderMap{
+ Headers: []*configPb.HeaderValue{
+ {
+ Key: "x-test-header",
+ Value: "test-value",
+ },
+ },
+ },
+ EndOfStream: false,
+ },
+ }
+
+ err := server.HandleRequestHeaders(reqCtx, req)
+ assert.NoError(t, err, "expected no error")
+ assert.Equal(t, defaultFairnessID, reqCtx.FairnessID, "expected fairness ID to be defaulted")
+}
diff --git a/pkg/epp/handlers/response.go b/pkg/epp/handlers/response.go
index d0c3b020a..9c2f44be5 100644
--- a/pkg/epp/handlers/response.go
+++ b/pkg/epp/handlers/response.go
@@ -19,6 +19,7 @@ package handlers
import (
"context"
"encoding/json"
+ "fmt"
"strings"
configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@@ -41,8 +42,7 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques
logger := log.FromContext(ctx)
responseBytes, err := json.Marshal(response)
if err != nil {
- logger.V(logutil.DEFAULT).Error(err, "error marshalling responseBody")
- return reqCtx, err
+ return reqCtx, fmt.Errorf("error marshalling responseBody: %w", err)
}
if response["usage"] != nil {
usg := response["usage"].(map[string]any)
@@ -63,24 +63,28 @@ func (s *StreamingServer) HandleResponseBody(ctx context.Context, reqCtx *Reques
reqCtx.ResponseComplete = true
reqCtx.respBodyResp = generateResponseBodyResponses(responseBytes, true, reqCtx, logger)
- return reqCtx, nil
+
+ return s.director.HandleResponseBodyComplete(ctx, reqCtx)
}
// The function is to handle streaming response if the modelServer is streaming.
func (s *StreamingServer) HandleResponseBodyModelStreaming(ctx context.Context, reqCtx *RequestContext, responseText string) {
+ logger := log.FromContext(ctx)
+ _, err := s.director.HandleResponseBodyStreaming(ctx, reqCtx)
+ if err != nil {
+ logger.Error(err, "error in HandleResponseBodyStreaming")
+ }
if strings.Contains(responseText, streamingEndMsg) {
reqCtx.ResponseComplete = true
resp := parseRespForUsage(ctx, responseText)
reqCtx.Usage = resp.Usage
metrics.RecordInputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.PromptTokens)
metrics.RecordOutputTokens(reqCtx.IncomingModelName, reqCtx.TargetModelName, resp.Usage.CompletionTokens)
- if s.director != nil {
- s.director.HandleResponseBodyComplete(ctx, reqCtx)
+ _, err := s.director.HandleResponseBodyComplete(ctx, reqCtx)
+ if err != nil {
+ logger.Error(err, "error in HandleResponseBodyComplete")
}
}
- if s.director != nil {
- s.director.HandleResponseBodyChunk(ctx, reqCtx)
- }
}
func (s *StreamingServer) HandleResponseHeaders(ctx context.Context, reqCtx *RequestContext, resp *extProcPb.ProcessingRequest_ResponseHeaders) (*RequestContext, error) {
@@ -92,7 +96,7 @@ func (s *StreamingServer) HandleResponseHeaders(ctx context.Context, reqCtx *Req
}
}
- reqCtx, err := s.director.HandleResponse(ctx, reqCtx)
+ reqCtx, err := s.director.HandleResponseReceived(ctx, reqCtx)
return reqCtx, err
}
@@ -173,16 +177,6 @@ func generateResponseBodyResponses(
continue
}
- // Add metrics to usage if present
- if usage, ok := obj["usage"].(map[string]interface{}); ok && usage != nil {
- usage["ttft_ms"] = reqCtx.TTFT
- usage["predicted_ttft_ms"] = reqCtx.PredictedTTFT
- usage["tpot_observations_ms"] = reqCtx.TPOTObservations
- usage["predicted_tpot_observations_ms"] = reqCtx.PredictedTPOTObservations
- usage["avg_tpot_ms"] = reqCtx.AvgTPOT
- usage["avg_predicted_tpot_ms"] = reqCtx.AvgPredictedTPOT
- }
-
// Re-marshal and reconstruct SSE format
if modifiedBytes, err := json.Marshal(obj); err != nil {
logger.Error(err, "failed to re-marshal modified JSON", "obj", obj)
diff --git a/pkg/epp/handlers/response_test.go b/pkg/epp/handlers/response_test.go
index 6eb7734e4..63b2de0da 100644
--- a/pkg/epp/handlers/response_test.go
+++ b/pkg/epp/handlers/response_test.go
@@ -23,6 +23,7 @@ import (
"github.com/google/go-cmp/cmp"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)
@@ -59,6 +60,27 @@ data: [DONE]
`
)
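+// mockDirector is a no-op Director implementation used to satisfy the StreamingServer dependency in these tests.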
+type mockDirector struct{}
+
+func (m *mockDirector) HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
+ return reqCtx, nil
+}
+func (m *mockDirector) HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
+ return reqCtx, nil
+}
+func (m *mockDirector) HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
+ return reqCtx, nil
+}
+func (m *mockDirector) HandlePreRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
+ return reqCtx, nil
+}
+func (m *mockDirector) GetRandomPod() *backend.Pod {
+ return &backend.Pod{}
+}
+func (m *mockDirector) HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error) {
+ return reqCtx, nil
+}
+
func TestHandleResponseBody(t *testing.T) {
ctx := logutil.NewTestLoggerIntoContext(context.Background())
@@ -83,6 +105,7 @@ func TestHandleResponseBody(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
server := &StreamingServer{}
+ server.director = &mockDirector{}
reqCtx := test.reqCtx
if reqCtx == nil {
reqCtx = &RequestContext{}
@@ -143,6 +166,7 @@ func TestHandleStreamedResponseBody(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
server := &StreamingServer{}
+ server.director = &mockDirector{}
reqCtx := test.reqCtx
if reqCtx == nil {
reqCtx = &RequestContext{}
diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go
index 0e0a1d03d..274dcf32c 100644
--- a/pkg/epp/handlers/server.go
+++ b/pkg/epp/handlers/server.go
@@ -33,7 +33,6 @@ import (
v1 "sigs.k8s.io/gateway-api-inference-extension/api/v1"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
- backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
@@ -55,9 +54,9 @@ func NewStreamingServer(datastore Datastore, director Director) *StreamingServer
type Director interface {
HandleRequest(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
- HandleResponse(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
- HandleResponseBodyChunk(ctx context.Context, reqCtx *RequestContext) error
- HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) error
+ HandleResponseReceived(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
+ HandleResponseBodyStreaming(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
+ HandleResponseBodyComplete(ctx context.Context, reqCtx *RequestContext) (*RequestContext, error)
GetRandomPod() *backend.Pod
}
@@ -85,7 +84,6 @@ type RequestContext struct {
ObjectiveKey string
RequestReceivedTimestamp time.Time
ResponseCompleteTimestamp time.Time
- LastTokenTimestamp time.Time
RequestSize int
Usage Usage
ResponseSize int
@@ -93,24 +91,12 @@ type RequestContext struct {
ResponseStatusCode string
RequestRunning bool
Request *Request
- GeneratedTokenCount int
- LastSeenMetrics map[string]*backendmetrics.MetricsState
- SchedulingResult *schedulingtypes.SchedulingResult
SchedulingRequest *schedulingtypes.LLMRequest
RequestState StreamRequestState
modelServerStreaming bool
- // -- New fields for latency predictor --
- TTFT float64
- PredictedTTFT float64
- AvgTPOT float64
- AvgPredictedTPOT float64
- TokenSampler *requtil.TokenSampler
- TPOTObservations []float64
- PredictedTPOTObservations []float64
-
Response *Response
reqHeaderResp *extProcPb.ProcessingResponse
@@ -138,7 +124,7 @@ const (
HeaderRequestResponseComplete StreamRequestState = 1
BodyRequestResponsesComplete StreamRequestState = 2
TrailerRequestResponsesComplete StreamRequestState = 3
- ResponseRecieved StreamRequestState = 4
+ ResponseReceived StreamRequestState = 4
HeaderResponseResponseComplete StreamRequestState = 5
BodyResponseResponsesComplete StreamRequestState = 6
TrailerResponseResponsesComplete StreamRequestState = 7
@@ -195,9 +181,6 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
return nil
}
if recvErr != nil {
- // This error occurs very frequently, though it doesn't seem to have any impact.
- // TODO Figure out if we can remove this noise.
- logger.V(logutil.DEFAULT).Error(err, "Cannot receive stream request")
return status.Errorf(codes.Unknown, "cannot receive stream request: %v", err)
}
@@ -272,7 +255,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer)
loggerTrace.Info("model server is streaming response")
}
}
- reqCtx.RequestState = ResponseRecieved
+ reqCtx.RequestState = ResponseReceived
var responseErr error
reqCtx, responseErr = s.HandleResponseHeaders(ctx, reqCtx, v)
@@ -399,7 +382,7 @@ func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProces
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
}
}
- if r.RequestState == ResponseRecieved && r.respHeaderResp != nil {
+ if r.RequestState == ResponseReceived && r.respHeaderResp != nil {
loggerTrace.Info("Sending response header response", "obj", r.respHeaderResp)
if err := srv.Send(r.respHeaderResp); err != nil {
return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
diff --git a/pkg/epp/metrics/collectors/inference_pool_test.go b/pkg/epp/metrics/collectors/inference_pool_test.go
index dcac3b37d..af2923e50 100644
--- a/pkg/epp/metrics/collectors/inference_pool_test.go
+++ b/pkg/epp/metrics/collectors/inference_pool_test.go
@@ -40,7 +40,7 @@ var (
Name: "pod1",
},
}
- pod1NamespacedName = types.NamespacedName{Name: pod1.Name, Namespace: pod1.Namespace}
+ pod1NamespacedName = types.NamespacedName{Name: pod1.Name + "-rank-0", Namespace: pod1.Namespace}
pod1Metrics = &backendmetrics.MetricsState{
WaitingQueueSize: 100,
KVCacheUsagePercent: 0.2,
@@ -50,10 +50,10 @@ var (
func TestNoMetricsCollected(t *testing.T) {
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
- datastore := datastore.NewDatastore(context.Background(), pmf)
+ ds := datastore.NewDatastore(context.Background(), pmf, 0)
collector := &inferencePoolMetricsCollector{
- ds: datastore,
+ ds: ds,
}
if err := testutil.CollectAndCompare(collector, strings.NewReader(""), ""); err != nil {
@@ -68,7 +68,7 @@ func TestMetricsCollected(t *testing.T) {
},
}
pmf := backendmetrics.NewPodMetricsFactory(pmc, time.Millisecond)
- ds := datastore.NewDatastore(context.Background(), pmf)
+ ds := datastore.NewDatastore(context.Background(), pmf, 0)
scheme := runtime.NewScheme()
fakeClient := fake.NewClientBuilder().
@@ -94,7 +94,7 @@ func TestMetricsCollected(t *testing.T) {
err := testutil.CollectAndCompare(collector, strings.NewReader(`
# HELP inference_pool_per_pod_queue_size [ALPHA] The total number of requests pending in the model server queue for each underlying pod.
# TYPE inference_pool_per_pod_queue_size gauge
- inference_pool_per_pod_queue_size{model_server_pod="pod1",name="test-pool"} 100
+ inference_pool_per_pod_queue_size{model_server_pod="pod1-rank-0",name="test-pool"} 100
`), "inference_pool_per_pod_queue_size")
if err != nil {
t.Fatal(err)
diff --git a/pkg/epp/metrics/metrics.go b/pkg/epp/metrics/metrics.go
index f5910099e..e8deaaab3 100644
--- a/pkg/epp/metrics/metrics.go
+++ b/pkg/epp/metrics/metrics.go
@@ -31,34 +31,34 @@ import (
)
const (
- InferenceModelComponent = "inference_model"
- InferencePoolComponent = "inference_pool"
- InferenceExtension = "inference_extension"
+ InferenceObjectiveComponent = "inference_objective"
+ InferencePoolComponent = "inference_pool"
+ InferenceExtension = "inference_extension"
)
var (
- // Inference Model Metrics
+ // Inference Objective Metrics
requestCounter = prometheus.NewCounterVec(
prometheus.CounterOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_total",
- Help: metricsutil.HelpMsgWithStability("Counter of inference model requests broken out for each model and target model.", compbasemetrics.ALPHA),
+ Help: metricsutil.HelpMsgWithStability("Counter of inference objective requests broken out for each model and target model.", compbasemetrics.ALPHA),
},
[]string{"model_name", "target_model_name"},
)
requestErrCounter = prometheus.NewCounterVec(
prometheus.CounterOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_error_total",
- Help: metricsutil.HelpMsgWithStability("Counter of inference model requests errors broken out for each model and target model.", compbasemetrics.ALPHA),
+ Help: metricsutil.HelpMsgWithStability("Counter of inference objective requests errors broken out for each model and target model.", compbasemetrics.ALPHA),
},
[]string{"model_name", "target_model_name", "error_code"},
)
requestTTFT = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_ttft_seconds",
Help: metricsutil.HelpMsgWithStability("Inference model TTFT distribution in seconds for each model and target model.", compbasemetrics.ALPHA),
Buckets: []float64{
@@ -71,7 +71,7 @@ var (
requestTTFTGauge = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_ttft_seconds_gauge",
Help: metricsutil.HelpMsgWithStability("Inference model TTFT gauge in seconds for each model and target model.", compbasemetrics.ALPHA),
},
@@ -80,7 +80,7 @@ var (
requestPredictedTTFT = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_predicted_ttft_seconds",
Help: metricsutil.HelpMsgWithStability("Inference model Predicted TTFT distribution in seconds for each model and target model.", compbasemetrics.ALPHA),
Buckets: []float64{
@@ -93,7 +93,7 @@ var (
requestPredictedTTFTGauge = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_predicted_ttft_seconds_gauge",
Help: metricsutil.HelpMsgWithStability("Inference model Predicted TTFT gauge in seconds for each model and target model.", compbasemetrics.ALPHA),
},
@@ -103,7 +103,7 @@ var (
// New metrics for TTFT prediction duration
requestTTFTPredictionDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_ttft_prediction_duration_seconds",
Help: metricsutil.HelpMsgWithStability("Duration taken to generate TTFT predictions in seconds for each model and target model.", compbasemetrics.ALPHA),
Buckets: []float64{
@@ -115,7 +115,7 @@ var (
requestTTFTPredictionDurationGauge = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_ttft_prediction_duration_seconds_gauge",
Help: metricsutil.HelpMsgWithStability("Latest duration taken to generate TTFT predictions in seconds for each model and target model.", compbasemetrics.ALPHA),
},
@@ -124,7 +124,7 @@ var (
requestTPOT = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_tpot_seconds",
Help: metricsutil.HelpMsgWithStability("Inference model TPOT distribution in seconds for each model and target model.", compbasemetrics.ALPHA),
Buckets: []float64{
@@ -137,7 +137,7 @@ var (
requestTPOTGauge = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_tpot_seconds_gauge",
Help: metricsutil.HelpMsgWithStability("Inference model TPOT gauge in seconds for each model and target model.", compbasemetrics.ALPHA),
},
@@ -145,7 +145,7 @@ var (
)
requestPredictedTPOT = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_predicted_tpot_seconds",
Help: metricsutil.HelpMsgWithStability("Inference model Predicted TPOT distribution in seconds for each model and target model.", compbasemetrics.ALPHA),
Buckets: []float64{
@@ -158,7 +158,7 @@ var (
requestPredictedTPOTGauge = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_predicted_tpot_seconds_gauge",
Help: metricsutil.HelpMsgWithStability("Inference model Predicted TPOT gauge in seconds for each model and target model.", compbasemetrics.ALPHA),
},
@@ -168,7 +168,7 @@ var (
// New metrics for TPOT prediction duration
requestTPOTPredictionDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_tpot_prediction_duration_seconds",
Help: metricsutil.HelpMsgWithStability("Duration taken to generate TPOT predictions in seconds for each model and target model.", compbasemetrics.ALPHA),
Buckets: []float64{
@@ -180,7 +180,7 @@ var (
requestTPOTPredictionDurationGauge = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_tpot_prediction_duration_seconds_gauge",
Help: metricsutil.HelpMsgWithStability("Latest duration taken to generate TPOT predictions in seconds for each model and target model.", compbasemetrics.ALPHA),
},
@@ -190,7 +190,7 @@ var (
// SLO Violation Metrics
requestTTFTSLOViolation = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_ttft_slo_violation",
Help: metricsutil.HelpMsgWithStability("Boolean indicator (0 or 1) of whether the last TTFT measurement violated the SLO threshold for each model and target model.", compbasemetrics.ALPHA),
},
@@ -199,7 +199,7 @@ var (
requestTTFTSLOViolationCounter = prometheus.NewCounterVec(
prometheus.CounterOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_ttft_slo_violation_total",
Help: metricsutil.HelpMsgWithStability("Counter of TTFT SLO violations for each model and target model.", compbasemetrics.ALPHA),
},
@@ -208,7 +208,7 @@ var (
requestTPOTSLOViolation = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_tpot_slo_violation",
Help: metricsutil.HelpMsgWithStability("Boolean indicator (0 or 1) of whether the last TPOT measurement violated the SLO threshold for each model and target model.", compbasemetrics.ALPHA),
},
@@ -217,7 +217,7 @@ var (
requestTPOTSLOViolationCounter = prometheus.NewCounterVec(
prometheus.CounterOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_tpot_slo_violation_total",
Help: metricsutil.HelpMsgWithStability("Counter of TPOT SLO violations for each model and target model.", compbasemetrics.ALPHA),
},
@@ -227,7 +227,7 @@ var (
// SLO threshold gauges (for dynamic threshold management)
requestTTFTSLOThreshold = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_ttft_slo_threshold_seconds",
Help: metricsutil.HelpMsgWithStability("Current TTFT SLO threshold in seconds for each model and target model.", compbasemetrics.ALPHA),
},
@@ -236,7 +236,7 @@ var (
requestTPOTSLOThreshold = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_tpot_slo_threshold_seconds",
Help: metricsutil.HelpMsgWithStability("Current TPOT SLO threshold in seconds for each model and target model.", compbasemetrics.ALPHA),
},
@@ -245,9 +245,9 @@ var (
requestLatencies = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_duration_seconds",
- Help: metricsutil.HelpMsgWithStability("Inference model response latency distribution in seconds for each model and target model.", compbasemetrics.ALPHA),
+ Help: metricsutil.HelpMsgWithStability("Inference objective response latency distribution in seconds for each model and target model.", compbasemetrics.ALPHA),
Buckets: []float64{
0.005, 0.025, 0.05, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.25, 1.5, 2, 3,
4, 5, 6, 8, 10, 15, 20, 30, 45, 60, 120, 180, 240, 300, 360, 480, 600, 900, 1200, 1800, 2700, 3600,
@@ -258,9 +258,9 @@ var (
requestSizes = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "request_sizes",
- Help: metricsutil.HelpMsgWithStability("Inference model requests size distribution in bytes for each model and target model.", compbasemetrics.ALPHA),
+ Help: metricsutil.HelpMsgWithStability("Inference objective requests size distribution in bytes for each model and target model.", compbasemetrics.ALPHA),
// Use buckets ranging from 1000 bytes (1KB) to 10^9 bytes (1GB).
Buckets: []float64{
64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, // More fine-grained up to 64KB
@@ -273,9 +273,9 @@ var (
responseSizes = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "response_sizes",
- Help: metricsutil.HelpMsgWithStability("Inference model responses size distribution in bytes for each model and target model.", compbasemetrics.ALPHA),
+ Help: metricsutil.HelpMsgWithStability("Inference objective responses size distribution in bytes for each model and target model.", compbasemetrics.ALPHA),
// Most models have a response token < 8192 tokens. Each token, in average, has 4 characters.
// 8192 * 4 = 32768.
Buckets: []float64{1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32778, 65536},
@@ -285,9 +285,9 @@ var (
inputTokens = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "input_tokens",
- Help: metricsutil.HelpMsgWithStability("Inference model input token count distribution for requests in each model.", compbasemetrics.ALPHA),
+ Help: metricsutil.HelpMsgWithStability("Inference objective input token count distribution for requests in each model.", compbasemetrics.ALPHA),
// Most models have a input context window less than 1 million tokens.
Buckets: []float64{1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32778, 65536, 131072, 262144, 524288, 1048576},
},
@@ -296,9 +296,9 @@ var (
outputTokens = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "output_tokens",
- Help: metricsutil.HelpMsgWithStability("Inference model output token count distribution for requests in each model.", compbasemetrics.ALPHA),
+ Help: metricsutil.HelpMsgWithStability("Inference objective output token count distribution for requests in each model.", compbasemetrics.ALPHA),
// Most models generates output less than 8192 tokens.
Buckets: []float64{1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192},
},
@@ -307,9 +307,9 @@ var (
runningRequests = prometheus.NewGaugeVec(
prometheus.GaugeOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "running_requests",
- Help: metricsutil.HelpMsgWithStability("Inference model number of running requests in each model.", compbasemetrics.ALPHA),
+ Help: metricsutil.HelpMsgWithStability("Inference objective number of running requests in each model.", compbasemetrics.ALPHA),
},
[]string{"model_name"},
)
@@ -317,9 +317,9 @@ var (
// NTPOT - Normalized Time Per Output Token
NormalizedTimePerOutputToken = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
- Subsystem: InferenceModelComponent,
+ Subsystem: InferenceObjectiveComponent,
Name: "normalized_time_per_output_token_seconds",
- Help: metricsutil.HelpMsgWithStability("Inference model latency divided by number of output tokens in seconds for each model and target model.", compbasemetrics.ALPHA),
+ Help: metricsutil.HelpMsgWithStability("Inference objective latency divided by number of output tokens in seconds for each model and target model.", compbasemetrics.ALPHA),
// From few milliseconds per token to multiple seconds per token
Buckets: []float64{
0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0,
@@ -422,6 +422,28 @@ var (
},
[]string{"commit", "build_ref"},
)
+
+ // Flow Control Metrics
+ flowControlRequestQueueDuration = prometheus.NewHistogramVec(
+ prometheus.HistogramOpts{
+ Subsystem: InferenceExtension,
+ Name: "flow_control_request_queue_duration_seconds",
+ Help: metricsutil.HelpMsgWithStability("Distribution of the total time requests spend in the EPP flow control layer, measured from the start of the EnqueueAndWait call until a final outcome is reached.", compbasemetrics.ALPHA),
+ Buckets: []float64{
+ 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0,
+ },
+ },
+ []string{"fairness_id", "priority", "outcome"},
+ )
+
+ flowControlQueueSize = prometheus.NewGaugeVec(
+ prometheus.GaugeOpts{
+ Subsystem: InferenceExtension,
+ Name: "flow_control_queue_size",
+ Help: metricsutil.HelpMsgWithStability("Current number of requests being actively managed by the EPP flow control layer, from the start of the EnqueueAndWait call until a final outcome is reached.", compbasemetrics.ALPHA),
+ },
+ []string{"fairness_id", "priority"},
+ )
)
var registerMetrics sync.Once
@@ -473,7 +495,8 @@ func Register(customCollectors ...prometheus.Collector) {
metrics.Registry.MustRegister(PrefixCacheSize)
metrics.Registry.MustRegister(PrefixCacheHitRatio)
metrics.Registry.MustRegister(PrefixCacheHitLength)
-
+ metrics.Registry.MustRegister(flowControlRequestQueueDuration)
+ metrics.Registry.MustRegister(flowControlQueueSize)
for _, collector := range customCollectors {
metrics.Registry.MustRegister(collector)
}
@@ -500,6 +523,8 @@ func Reset() {
PrefixCacheSize.Reset()
PrefixCacheHitRatio.Reset()
PrefixCacheHitLength.Reset()
+ flowControlRequestQueueDuration.Reset()
+ flowControlQueueSize.Reset()
requestTPOT.Reset()
requestTTFT.Reset()
@@ -770,6 +795,21 @@ func RecordInferenceExtensionInfo(commitSha, buildRef string) {
InferenceExtensionInfo.WithLabelValues(commitSha, buildRef).Set(1)
}
+// RecordFlowControlRequestQueueDuration records the duration a request spent in the Flow Control layer.
+func RecordFlowControlRequestQueueDuration(fairnessID, priority, outcome string, duration time.Duration) {
+ flowControlRequestQueueDuration.WithLabelValues(fairnessID, priority, outcome).Observe(duration.Seconds())
+}
+
+// IncFlowControlQueueSize increments the Flow Control queue size gauge.
+func IncFlowControlQueueSize(fairnessID, priority string) {
+ flowControlQueueSize.WithLabelValues(fairnessID, priority).Inc()
+}
+
+// DecFlowControlQueueSize decrements the Flow Control queue size gauge.
+func DecFlowControlQueueSize(fairnessID, priority string) {
+ flowControlQueueSize.WithLabelValues(fairnessID, priority).Dec()
+}
+
// SetTTFTSLOThreshold sets the TTFT SLO threshold for a model.
// This allows dynamic threshold management and makes the threshold visible in metrics.
func SetTTFTSLOThreshold(modelName, targetModelName string, threshold float64) {
diff --git a/pkg/epp/metrics/metrics_test.go b/pkg/epp/metrics/metrics_test.go
index f1bb23f64..736d6854c 100644
--- a/pkg/epp/metrics/metrics_test.go
+++ b/pkg/epp/metrics/metrics_test.go
@@ -22,6 +22,9 @@ import (
"testing"
"time"
+ "github.com/prometheus/client_golang/prometheus"
+ dto "github.com/prometheus/client_model/go"
+ "github.com/stretchr/testify/require"
"k8s.io/component-base/metrics/testutil"
"sigs.k8s.io/controller-runtime/pkg/metrics"
@@ -30,25 +33,32 @@ import (
)
const (
- RequestTotalMetric = InferenceModelComponent + "_request_total"
- RequestErrorTotalMetric = InferenceModelComponent + "_request_error_total"
- RequestLatenciesMetric = InferenceModelComponent + "_request_duration_seconds"
- RequestSizesMetric = InferenceModelComponent + "_request_sizes"
- ResponseSizesMetric = InferenceModelComponent + "_response_sizes"
- InputTokensMetric = InferenceModelComponent + "_input_tokens"
- OutputTokensMetric = InferenceModelComponent + "_output_tokens"
- NormalizedTimePerOutputTokenMetric = InferenceModelComponent + "_normalized_time_per_output_token_seconds"
- RunningRequestsMetric = InferenceModelComponent + "_running_requests"
+ RequestTotalMetric = InferenceObjectiveComponent + "_request_total"
+ RequestErrorTotalMetric = InferenceObjectiveComponent + "_request_error_total"
+ RequestLatenciesMetric = InferenceObjectiveComponent + "_request_duration_seconds"
+ RequestSizesMetric = InferenceObjectiveComponent + "_request_sizes"
+ ResponseSizesMetric = InferenceObjectiveComponent + "_response_sizes"
+ InputTokensMetric = InferenceObjectiveComponent + "_input_tokens"
+ OutputTokensMetric = InferenceObjectiveComponent + "_output_tokens"
+ NormalizedTimePerOutputTokenMetric = InferenceObjectiveComponent + "_normalized_time_per_output_token_seconds"
+ RunningRequestsMetric = InferenceObjectiveComponent + "_running_requests"
KVCacheAvgUsageMetric = InferencePoolComponent + "_average_kv_cache_utilization"
QueueAvgSizeMetric = InferencePoolComponent + "_average_queue_size"
PerPodQueueSizeMetrics = InferencePoolComponent + "_per_pod_queue_size"
- RequestTTFTSecondsMetric = InferenceModelComponent + "_request_ttft_seconds"
- RequestTPOTSecondsMetric = InferenceModelComponent + "_request_tpot_seconds"
- RequestTTFTPredictionsMAPEMetric = InferenceModelComponent + "_request_ttft_predictions_mape"
- RequestTPOTPredictionsMAPEMetric = InferenceModelComponent + "_request_tpot_predictions_mape"
+ RequestTTFTSecondsMetric = InferenceObjectiveComponent + "_request_ttft_seconds"
+ RequestTPOTSecondsMetric = InferenceObjectiveComponent + "_request_tpot_seconds"
+ RequestTTFTPredictionsMAPEMetric = InferenceObjectiveComponent + "_request_ttft_predictions_mape"
+ RequestTPOTPredictionsMAPEMetric = InferenceObjectiveComponent + "_request_tpot_predictions_mape"
)
+func TestMain(m *testing.M) {
+ // Register all metrics once for the entire test suite.
+ Register()
+ os.Exit(m.Run())
+}
+
func TestRecordRequestCounterandSizes(t *testing.T) {
+ Reset()
type requests struct {
modelName string
targetModelName string
@@ -82,7 +92,6 @@ func TestRecordRequestCounterandSizes(t *testing.T) {
},
},
}}
- Register()
for _, scenario := range scenarios {
t.Run(scenario.name, func(t *testing.T) {
for _, req := range scenario.reqs {
@@ -118,6 +127,7 @@ func TestRecordRequestCounterandSizes(t *testing.T) {
}
func TestRecordRequestErrorCounter(t *testing.T) {
+ Reset()
type requests struct {
modelName string
targetModelName string
@@ -154,7 +164,6 @@ func TestRecordRequestErrorCounter(t *testing.T) {
},
},
}
- Register()
for _, scenario := range scenarios {
t.Run(scenario.name, func(t *testing.T) {
for _, req := range scenario.reqs {
@@ -178,6 +187,7 @@ func TestRecordRequestErrorCounter(t *testing.T) {
}
func TestRecordRequestLatencies(t *testing.T) {
+ Reset()
ctx := logutil.NewTestLoggerIntoContext(context.Background())
timeBaseline := time.Now()
type requests struct {
@@ -233,7 +243,6 @@ func TestRecordRequestLatencies(t *testing.T) {
invalid: true,
},
}
- Register()
for _, scenario := range scenarios {
t.Run(scenario.name, func(t *testing.T) {
for _, req := range scenario.reqs {
@@ -260,6 +269,7 @@ func TestRecordRequestLatencies(t *testing.T) {
}
func TestRecordNormalizedTimePerOutputToken(t *testing.T) {
+ Reset()
ctx := logutil.NewTestLoggerIntoContext(context.Background())
timeBaseline := time.Now()
type tokenRequests struct {
@@ -334,7 +344,6 @@ func TestRecordNormalizedTimePerOutputToken(t *testing.T) {
invalid: true,
},
}
- Register()
for _, scenario := range scenarios {
t.Run(scenario.name, func(t *testing.T) {
for _, req := range scenario.reqs {
@@ -361,6 +370,7 @@ func TestRecordNormalizedTimePerOutputToken(t *testing.T) {
}
func TestRecordResponseMetrics(t *testing.T) {
+ Reset()
type responses struct {
modelName string
targetModelName string
@@ -404,7 +414,6 @@ func TestRecordResponseMetrics(t *testing.T) {
},
},
}}
- Register()
for _, scenario := range scenarios {
t.Run(scenario.name, func(t *testing.T) {
for _, resp := range scenario.resp {
@@ -455,6 +464,7 @@ func TestRecordResponseMetrics(t *testing.T) {
}
func TestRunningRequestsMetrics(t *testing.T) {
+ Reset()
type request struct {
modelName string
complete bool // true -> request is completed, false -> running request
@@ -487,7 +497,6 @@ func TestRunningRequestsMetrics(t *testing.T) {
},
}
- Register()
for _, scenario := range scenarios {
t.Run(scenario.name, func(t *testing.T) {
for _, req := range scenario.requests {
@@ -515,6 +524,7 @@ func TestRunningRequestsMetrics(t *testing.T) {
}
func TestInferencePoolMetrics(t *testing.T) {
+ Reset()
scenarios := []struct {
name string
poolName string
@@ -528,7 +538,6 @@ func TestInferencePoolMetrics(t *testing.T) {
queueSizeAvg: 0.4,
},
}
- Register()
for _, scenario := range scenarios {
t.Run(scenario.name, func(t *testing.T) {
RecordInferencePoolAvgKVCache(scenario.poolName, scenario.kvCacheAvg)
@@ -564,6 +573,7 @@ func TestInferencePoolMetrics(t *testing.T) {
}
func TestPluginProcessingLatencies(t *testing.T) {
+ Reset()
type pluginLatency struct {
extensionPoint string
pluginType string
@@ -604,7 +614,6 @@ func TestPluginProcessingLatencies(t *testing.T) {
},
},
}
- Register()
for _, scenario := range scenarios {
t.Run(scenario.name, func(t *testing.T) {
for _, latency := range scenario.latencies {
@@ -628,6 +637,7 @@ func TestPluginProcessingLatencies(t *testing.T) {
}
func TestSchedulerE2ELatency(t *testing.T) {
+ Reset()
scenarios := []struct {
name string
durations []time.Duration
@@ -647,7 +657,6 @@ func TestSchedulerE2ELatency(t *testing.T) {
},
},
}
- Register()
for _, scenario := range scenarios {
t.Run(scenario.name, func(t *testing.T) {
for _, duration := range scenario.durations {
@@ -671,6 +680,7 @@ func TestSchedulerE2ELatency(t *testing.T) {
}
func TestPrefixCacheMetrics(t *testing.T) {
+ Reset()
const (
PrefixCacheSizeMetric = InferenceExtension + "_prefix_indexer_size"
PrefixCacheHitRatioMetric = InferenceExtension + "_prefix_indexer_hit_ratio"
@@ -717,7 +727,6 @@ func TestPrefixCacheMetrics(t *testing.T) {
},
}
- Register()
t.Run(scenario.name, func(t *testing.T) {
// Record cache size metrics
for _, size := range scenario.cacheSizes {
@@ -772,3 +781,104 @@ func TestPrefixCacheMetrics(t *testing.T) {
}
})
}
+
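+// getHistogramVecLabelValues extracts the underlying dto.Histogram for one label combination so tests can assert on
+// its sample count and sum.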
+func getHistogramVecLabelValues(t *testing.T, h *prometheus.HistogramVec, labelValues ...string) (*dto.Histogram, error) {
+ t.Helper()
+ m, err := h.GetMetricWithLabelValues(labelValues...)
+ if err != nil {
+ return nil, err
+ }
+ metricDto := &dto.Metric{}
+ if err := m.(prometheus.Histogram).Write(metricDto); err != nil {
+ return nil, err
+ }
+ return metricDto.GetHistogram(), nil
+}
+
+func TestFlowControlQueueDurationMetric(t *testing.T) {
+ Reset()
+
+ records := []struct {
+ fairnessID string
+ priority string
+ outcome string
+ duration time.Duration
+ }{
+ {fairnessID: "user-a", priority: "100", outcome: "Dispatched", duration: 10 * time.Millisecond},
+ {fairnessID: "user-a", priority: "100", outcome: "Dispatched", duration: 20 * time.Millisecond},
+ {fairnessID: "user-b", priority: "100", outcome: "RejectedCapacity", duration: 5 * time.Millisecond},
+ {fairnessID: "user-a", priority: "50", outcome: "Dispatched", duration: 100 * time.Millisecond},
+ }
+
+ for _, rec := range records {
+ RecordFlowControlRequestQueueDuration(rec.fairnessID, rec.priority, rec.outcome, rec.duration)
+ }
+
+ testCases := []struct {
+ name string
+ labels prometheus.Labels
+ expectCount uint64
+ expectSum float64
+ }{
+ {
+ name: "user-a, prio 100, dispatched",
+ labels: prometheus.Labels{"fairness_id": "user-a", "priority": "100", "outcome": "Dispatched"},
+ expectCount: 2,
+ expectSum: 0.03, // 0.01 + 0.02
+ },
+ {
+ name: "user-b, prio 100, rejected",
+ labels: prometheus.Labels{"fairness_id": "user-b", "priority": "100", "outcome": "RejectedCapacity"},
+ expectCount: 1,
+ expectSum: 0.005,
+ },
+ {
+ name: "user-a, prio 50, dispatched",
+ labels: prometheus.Labels{"fairness_id": "user-a", "priority": "50", "outcome": "Dispatched"},
+ expectCount: 1,
+ expectSum: 0.1,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
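+			// Label values are positional and must follow the order the histogram
+			// declares its labels: fairness_id, priority, outcome.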
+ labels := []string{tc.labels["fairness_id"], tc.labels["priority"], tc.labels["outcome"]}
+ hist, err := getHistogramVecLabelValues(t, flowControlRequestQueueDuration, labels...)
+ require.NoError(t, err, "Failed to get histogram for labels %v", tc.labels)
+ require.Equal(t, tc.expectCount, hist.GetSampleCount(), "Sample count mismatch for labels %v", tc.labels)
+ require.InDelta(t, tc.expectSum, hist.GetSampleSum(), 0.00001, "Sample sum mismatch for labels %v", tc.labels)
+ })
+ }
+}
+
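+// TestFlowControlQueueSizeMetric exercises the flow control queue size gauge:
+// increments and decrements per (fairness_id, priority) pair, independence of
+// separate label combinations, and the zero default for untouched series.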
+func TestFlowControlQueueSizeMetric(t *testing.T) {
+ Reset()
+
+ // Basic Inc/Dec
+ IncFlowControlQueueSize("user-a", "100")
+ val, err := testutil.GetGaugeMetricValue(flowControlQueueSize.WithLabelValues("user-a", "100"))
+ require.NoError(t, err, "Failed to get gauge value for user-a/100 after Inc")
+ require.Equal(t, 1.0, val, "Gauge value should be 1 after Inc for user-a/100")
+
+ DecFlowControlQueueSize("user-a", "100")
+ val, err = testutil.GetGaugeMetricValue(flowControlQueueSize.WithLabelValues("user-a", "100"))
+ require.NoError(t, err, "Failed to get gauge value for user-a/100 after Dec")
+ require.Equal(t, 0.0, val, "Gauge value should be 0 after Dec for user-a/100")
+
+ // Multiple labels
+ IncFlowControlQueueSize("user-b", "200")
+ IncFlowControlQueueSize("user-b", "200")
+ val, err = testutil.GetGaugeMetricValue(flowControlQueueSize.WithLabelValues("user-b", "200"))
+ require.NoError(t, err, "Failed to get gauge value for user-b/200")
+ require.Equal(t, 2.0, val, "Gauge value should be 2 for user-b/200")
+
+ DecFlowControlQueueSize("user-b", "200")
+ val, err = testutil.GetGaugeMetricValue(flowControlQueueSize.WithLabelValues("user-b", "200"))
+ require.NoError(t, err, "Failed to get gauge value for user-b/200 after one Dec")
+ require.Equal(t, 1.0, val, "Gauge value should be 1 for user-b/200 after one Dec")
+
+	// Label combination that was never incremented: fetching it reports the gauge's zero value.
+ val, err = testutil.GetGaugeMetricValue(flowControlQueueSize.WithLabelValues("user-c", "100"))
+ require.NoError(t, err, "Failed to get gauge value for non-existent user-c/100")
+ require.Equal(t, 0.0, val, "Gauge value for non-existent labels should be 0")
+}
diff --git a/pkg/epp/metrics/testdata/input_tokens_metric b/pkg/epp/metrics/testdata/input_tokens_metric
index 245c7dfa7..5ec493f52 100644
--- a/pkg/epp/metrics/testdata/input_tokens_metric
+++ b/pkg/epp/metrics/testdata/input_tokens_metric
@@ -1,68 +1,68 @@
-# HELP inference_model_input_tokens [ALPHA] Inference model input token count distribution for requests in each model.
-# TYPE inference_model_input_tokens histogram
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="1"} 0
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="8"} 0
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="16"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="32"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="64"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="128"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="256"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="512"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="1024"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="2048"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="4096"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="8192"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="16384"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="32778"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="65536"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="131072"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="262144"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="524288"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="1.048576e+06"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t10",le="+Inf"} 2
-inference_model_input_tokens_sum{model_name="m10",target_model_name="t10"} 30
-inference_model_input_tokens_count{model_name="m10",target_model_name="t10"} 2
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="1"} 0
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="8"} 0
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="16"} 0
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="32"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="64"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="128"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="256"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="512"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="1024"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="2048"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="4096"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="8192"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="16384"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="32778"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="65536"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="131072"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="262144"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="524288"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="1.048576e+06"} 1
-inference_model_input_tokens_bucket{model_name="m10",target_model_name="t11",le="+Inf"} 1
-inference_model_input_tokens_sum{model_name="m10",target_model_name="t11"} 30
-inference_model_input_tokens_count{model_name="m10",target_model_name="t11"} 1
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="1"} 0
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="8"} 0
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="16"} 0
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="32"} 0
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="64"} 1
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="128"} 1
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="256"} 1
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="512"} 1
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="1024"} 1
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="2048"} 1
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="4096"} 1
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="8192"} 1
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="16384"} 1
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="32778"} 1
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="65536"} 1
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="131072"} 1
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="262144"} 1
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="524288"} 1
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="1.048576e+06"} 1
-inference_model_input_tokens_bucket{model_name="m20",target_model_name="t20",le="+Inf"} 1
-inference_model_input_tokens_sum{model_name="m20",target_model_name="t20"} 40
-inference_model_input_tokens_count{model_name="m20",target_model_name="t20"} 1
+# HELP inference_objective_input_tokens [ALPHA] Inference objective input token count distribution for requests in each model.
+# TYPE inference_objective_input_tokens histogram
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="1"} 0
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="8"} 0
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="16"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="32"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="64"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="128"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="256"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="512"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="1024"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="2048"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="4096"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="8192"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="16384"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="32778"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="65536"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="131072"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="262144"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="524288"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="1.048576e+06"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t10",le="+Inf"} 2
+inference_objective_input_tokens_sum{model_name="m10",target_model_name="t10"} 30
+inference_objective_input_tokens_count{model_name="m10",target_model_name="t10"} 2
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="1"} 0
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="8"} 0
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="16"} 0
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="32"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="64"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="128"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="256"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="512"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="1024"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="2048"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="4096"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="8192"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="16384"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="32778"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="65536"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="131072"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="262144"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="524288"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="1.048576e+06"} 1
+inference_objective_input_tokens_bucket{model_name="m10",target_model_name="t11",le="+Inf"} 1
+inference_objective_input_tokens_sum{model_name="m10",target_model_name="t11"} 30
+inference_objective_input_tokens_count{model_name="m10",target_model_name="t11"} 1
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="1"} 0
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="8"} 0
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="16"} 0
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="32"} 0
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="64"} 1
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="128"} 1
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="256"} 1
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="512"} 1
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="1024"} 1
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="2048"} 1
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="4096"} 1
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="8192"} 1
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="16384"} 1
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="32778"} 1
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="65536"} 1
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="131072"} 1
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="262144"} 1
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="524288"} 1
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="1.048576e+06"} 1
+inference_objective_input_tokens_bucket{model_name="m20",target_model_name="t20",le="+Inf"} 1
+inference_objective_input_tokens_sum{model_name="m20",target_model_name="t20"} 40
+inference_objective_input_tokens_count{model_name="m20",target_model_name="t20"} 1
diff --git a/pkg/epp/metrics/testdata/normalized_time_per_output_token_seconds_metric b/pkg/epp/metrics/testdata/normalized_time_per_output_token_seconds_metric
index bb6e93737..0a9c83ea4 100644
--- a/pkg/epp/metrics/testdata/normalized_time_per_output_token_seconds_metric
+++ b/pkg/epp/metrics/testdata/normalized_time_per_output_token_seconds_metric
@@ -1,50 +1,50 @@
-# HELP inference_model_normalized_time_per_output_token_seconds [ALPHA] Inference model latency divided by number of output tokens in seconds for each model and target model.
-# TYPE inference_model_normalized_time_per_output_token_seconds histogram
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.001"} 0
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.002"} 0
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.005"} 0
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.01"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.02"} 2
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.05"} 2
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.1"} 2
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.2"} 2
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.5"} 2
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="1.0"} 2
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="2.0"} 2
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="5.0"} 2
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="10.0"} 2
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="+Inf"} 2
-inference_model_normalized_time_per_output_token_seconds_sum{model_name="m10", target_model_name="t10"} 0.03
-inference_model_normalized_time_per_output_token_seconds_count{model_name="m10", target_model_name="t10"} 2
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.001"} 0
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.002"} 0
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.005"} 0
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.01"} 0
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.02"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.05"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.1"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.2"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.5"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="1.0"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="2.0"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="5.0"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="10.0"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="+Inf"} 1
-inference_model_normalized_time_per_output_token_seconds_sum{model_name="m10", target_model_name="t11"} 0.02
-inference_model_normalized_time_per_output_token_seconds_count{model_name="m10", target_model_name="t11"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.001"} 0
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.002"} 0
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.005"} 0
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.01"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.02"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.05"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.1"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.2"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.5"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="1.0"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="2.0"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="5.0"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="10.0"} 1
-inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="+Inf"} 1
-inference_model_normalized_time_per_output_token_seconds_sum{model_name="m20", target_model_name="t20"} 0.006
-inference_model_normalized_time_per_output_token_seconds_count{model_name="m20", target_model_name="t20"} 1
+# HELP inference_objective_normalized_time_per_output_token_seconds [ALPHA] Inference objective latency divided by number of output tokens in seconds for each model and target model.
+# TYPE inference_objective_normalized_time_per_output_token_seconds histogram
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.001"} 0
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.002"} 0
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.005"} 0
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.01"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.02"} 2
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.05"} 2
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.1"} 2
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.2"} 2
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.5"} 2
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="1.0"} 2
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="2.0"} 2
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="5.0"} 2
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="10.0"} 2
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="+Inf"} 2
+inference_objective_normalized_time_per_output_token_seconds_sum{model_name="m10", target_model_name="t10"} 0.03
+inference_objective_normalized_time_per_output_token_seconds_count{model_name="m10", target_model_name="t10"} 2
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.001"} 0
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.002"} 0
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.005"} 0
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.01"} 0
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.02"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.05"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.1"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.2"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.5"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="1.0"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="2.0"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="5.0"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="10.0"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="+Inf"} 1
+inference_objective_normalized_time_per_output_token_seconds_sum{model_name="m10", target_model_name="t11"} 0.02
+inference_objective_normalized_time_per_output_token_seconds_count{model_name="m10", target_model_name="t11"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.001"} 0
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.002"} 0
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.005"} 0
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.01"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.02"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.05"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.1"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.2"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.5"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="1.0"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="2.0"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="5.0"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="10.0"} 1
+inference_objective_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="+Inf"} 1
+inference_objective_normalized_time_per_output_token_seconds_sum{model_name="m20", target_model_name="t20"} 0.006
+inference_objective_normalized_time_per_output_token_seconds_count{model_name="m20", target_model_name="t20"} 1
diff --git a/pkg/epp/metrics/testdata/output_tokens_metric b/pkg/epp/metrics/testdata/output_tokens_metric
index 40bbe3272..5b71ca0a3 100644
--- a/pkg/epp/metrics/testdata/output_tokens_metric
+++ b/pkg/epp/metrics/testdata/output_tokens_metric
@@ -1,47 +1,47 @@
-# HELP inference_model_output_tokens [ALPHA] Inference model output token count distribution for requests in each model.
-# TYPE inference_model_output_tokens histogram
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t10",le="1"} 0
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t10",le="8"} 0
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t10",le="16"} 0
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t10",le="32"} 0
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t10",le="64"} 0
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t10",le="128"} 1
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t10",le="256"} 2
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t10",le="512"} 2
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t10",le="1024"} 2
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t10",le="2048"} 2
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t10",le="4096"} 2
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t10",le="8192"} 2
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t10",le="+Inf"} 2
-inference_model_output_tokens_sum{model_name="m10",target_model_name="t10"} 300
-inference_model_output_tokens_count{model_name="m10",target_model_name="t10"} 2
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t11",le="1"} 0
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t11",le="8"} 0
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t11",le="16"} 0
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t11",le="32"} 0
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t11",le="64"} 0
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t11",le="128"} 0
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t11",le="256"} 0
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t11",le="512"} 1
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t11",le="1024"} 1
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t11",le="2048"} 1
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t11",le="4096"} 1
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t11",le="8192"} 1
-inference_model_output_tokens_bucket{model_name="m10",target_model_name="t11",le="+Inf"} 1
-inference_model_output_tokens_sum{model_name="m10",target_model_name="t11"} 300
-inference_model_output_tokens_count{model_name="m10",target_model_name="t11"} 1
-inference_model_output_tokens_bucket{model_name="m20",target_model_name="t20",le="1"} 0
-inference_model_output_tokens_bucket{model_name="m20",target_model_name="t20",le="8"} 0
-inference_model_output_tokens_bucket{model_name="m20",target_model_name="t20",le="16"} 0
-inference_model_output_tokens_bucket{model_name="m20",target_model_name="t20",le="32"} 0
-inference_model_output_tokens_bucket{model_name="m20",target_model_name="t20",le="64"} 0
-inference_model_output_tokens_bucket{model_name="m20",target_model_name="t20",le="128"} 0
-inference_model_output_tokens_bucket{model_name="m20",target_model_name="t20",le="256"} 0
-inference_model_output_tokens_bucket{model_name="m20",target_model_name="t20",le="512"} 1
-inference_model_output_tokens_bucket{model_name="m20",target_model_name="t20",le="1024"} 1
-inference_model_output_tokens_bucket{model_name="m20",target_model_name="t20",le="2048"} 1
-inference_model_output_tokens_bucket{model_name="m20",target_model_name="t20",le="4096"} 1
-inference_model_output_tokens_bucket{model_name="m20",target_model_name="t20",le="8192"} 1
-inference_model_output_tokens_bucket{model_name="m20",target_model_name="t20",le="+Inf"} 1
-inference_model_output_tokens_sum{model_name="m20",target_model_name="t20"} 400
-inference_model_output_tokens_count{model_name="m20",target_model_name="t20"} 1
+# HELP inference_objective_output_tokens [ALPHA] Inference objective output token count distribution for requests in each model.
+# TYPE inference_objective_output_tokens histogram
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t10",le="1"} 0
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t10",le="8"} 0
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t10",le="16"} 0
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t10",le="32"} 0
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t10",le="64"} 0
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t10",le="128"} 1
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t10",le="256"} 2
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t10",le="512"} 2
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t10",le="1024"} 2
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t10",le="2048"} 2
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t10",le="4096"} 2
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t10",le="8192"} 2
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t10",le="+Inf"} 2
+inference_objective_output_tokens_sum{model_name="m10",target_model_name="t10"} 300
+inference_objective_output_tokens_count{model_name="m10",target_model_name="t10"} 2
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t11",le="1"} 0
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t11",le="8"} 0
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t11",le="16"} 0
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t11",le="32"} 0
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t11",le="64"} 0
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t11",le="128"} 0
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t11",le="256"} 0
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t11",le="512"} 1
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t11",le="1024"} 1
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t11",le="2048"} 1
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t11",le="4096"} 1
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t11",le="8192"} 1
+inference_objective_output_tokens_bucket{model_name="m10",target_model_name="t11",le="+Inf"} 1
+inference_objective_output_tokens_sum{model_name="m10",target_model_name="t11"} 300
+inference_objective_output_tokens_count{model_name="m10",target_model_name="t11"} 1
+inference_objective_output_tokens_bucket{model_name="m20",target_model_name="t20",le="1"} 0
+inference_objective_output_tokens_bucket{model_name="m20",target_model_name="t20",le="8"} 0
+inference_objective_output_tokens_bucket{model_name="m20",target_model_name="t20",le="16"} 0
+inference_objective_output_tokens_bucket{model_name="m20",target_model_name="t20",le="32"} 0
+inference_objective_output_tokens_bucket{model_name="m20",target_model_name="t20",le="64"} 0
+inference_objective_output_tokens_bucket{model_name="m20",target_model_name="t20",le="128"} 0
+inference_objective_output_tokens_bucket{model_name="m20",target_model_name="t20",le="256"} 0
+inference_objective_output_tokens_bucket{model_name="m20",target_model_name="t20",le="512"} 1
+inference_objective_output_tokens_bucket{model_name="m20",target_model_name="t20",le="1024"} 1
+inference_objective_output_tokens_bucket{model_name="m20",target_model_name="t20",le="2048"} 1
+inference_objective_output_tokens_bucket{model_name="m20",target_model_name="t20",le="4096"} 1
+inference_objective_output_tokens_bucket{model_name="m20",target_model_name="t20",le="8192"} 1
+inference_objective_output_tokens_bucket{model_name="m20",target_model_name="t20",le="+Inf"} 1
+inference_objective_output_tokens_sum{model_name="m20",target_model_name="t20"} 400
+inference_objective_output_tokens_count{model_name="m20",target_model_name="t20"} 1
diff --git a/pkg/epp/metrics/testdata/request_duration_seconds_metric b/pkg/epp/metrics/testdata/request_duration_seconds_metric
index 6c70b4ba9..cd6f0c061 100644
--- a/pkg/epp/metrics/testdata/request_duration_seconds_metric
+++ b/pkg/epp/metrics/testdata/request_duration_seconds_metric
@@ -1,116 +1,116 @@
-# HELP inference_model_request_duration_seconds [ALPHA] Inference model response latency distribution in seconds for each model and target model.
-# TYPE inference_model_request_duration_seconds histogram
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="0.005"} 0
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="0.025"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="0.05"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="0.1"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="0.2"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="0.4"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="0.6"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="0.8"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="1.0"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="1.25"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="1.5"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="2"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="3"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="4"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="5"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="6"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="8"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="10"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="15"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="20"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="30"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="45"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="60"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="120"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="180"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="240"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="300"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="360"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="480"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="600"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="900"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="1200"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="1800"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="2700"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="3600"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="Inf"} 2
-inference_model_request_duration_seconds_sum{model_name="m10", target_model_name="t10"} 1.61
-inference_model_request_duration_seconds_count{model_name="m10", target_model_name="t10"} 2
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="0.005"} 0
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="0.025"} 0
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="0.05"} 0
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="0.1"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="0.2"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="0.4"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="0.6"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="0.8"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="1"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="1.25"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="1.5"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="2"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="3"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="4"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="5"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="6"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="8"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="10"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="15"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="20"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="30"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="45"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="60"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="120"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="180"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="240"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="300"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="360"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="480"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="600"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="900"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="1200"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="1800"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="2700"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="3600"} 1
-inference_model_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="+Inf"} 1
-inference_model_request_duration_seconds_sum{model_name="m10",target_model_name="t11"} 0.06
-inference_model_request_duration_seconds_count{model_name="m10",target_model_name="t11"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="0.005"} 0
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="0.025"} 0
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="0.05"} 0
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="0.1"} 0
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="0.2"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="0.4"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="0.6"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="0.8"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="1"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="1.25"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="1.5"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="2"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="3"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="4"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="5"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="6"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="8"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="10"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="15"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="20"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="30"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="45"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="60"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="120"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="180"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="240"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="300"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="360"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="480"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="600"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="900"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="1200"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="1800"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="2700"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="3600"} 1
-inference_model_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="+Inf"} 1
-inference_model_request_duration_seconds_sum{model_name="m20",target_model_name="t20"} 0.12
-inference_model_request_duration_seconds_count{model_name="m20",target_model_name="t20"} 1
+# HELP inference_objective_request_duration_seconds [ALPHA] Inference objective response latency distribution in seconds for each model and target model.
+# TYPE inference_objective_request_duration_seconds histogram
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="0.005"} 0
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="0.025"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="0.05"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="0.1"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="0.2"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="0.4"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="0.6"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="0.8"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="1.0"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="1.25"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="1.5"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="2"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="3"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="4"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="5"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="6"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="8"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="10"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="15"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="20"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="30"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="45"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="60"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="120"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="180"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="240"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="300"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="360"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="480"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="600"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="900"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="1200"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="1800"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="2700"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="3600"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10", target_model_name="t10", le="Inf"} 2
+inference_objective_request_duration_seconds_sum{model_name="m10", target_model_name="t10"} 1.61
+inference_objective_request_duration_seconds_count{model_name="m10", target_model_name="t10"} 2
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="0.005"} 0
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="0.025"} 0
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="0.05"} 0
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="0.1"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="0.2"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="0.4"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="0.6"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="0.8"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="1"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="1.25"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="1.5"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="2"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="3"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="4"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="5"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="6"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="8"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="10"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="15"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="20"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="30"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="45"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="60"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="120"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="180"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="240"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="300"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="360"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="480"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="600"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="900"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="1200"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="1800"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="2700"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="3600"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m10",target_model_name="t11",le="+Inf"} 1
+inference_objective_request_duration_seconds_sum{model_name="m10",target_model_name="t11"} 0.06
+inference_objective_request_duration_seconds_count{model_name="m10",target_model_name="t11"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="0.005"} 0
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="0.025"} 0
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="0.05"} 0
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="0.1"} 0
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="0.2"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="0.4"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="0.6"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="0.8"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="1"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="1.25"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="1.5"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="2"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="3"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="4"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="5"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="6"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="8"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="10"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="15"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="20"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="30"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="45"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="60"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="120"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="180"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="240"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="300"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="360"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="480"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="600"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="900"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="1200"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="1800"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="2700"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="3600"} 1
+inference_objective_request_duration_seconds_bucket{model_name="m20",target_model_name="t20",le="+Inf"} 1
+inference_objective_request_duration_seconds_sum{model_name="m20",target_model_name="t20"} 0.12
+inference_objective_request_duration_seconds_count{model_name="m20",target_model_name="t20"} 1
diff --git a/pkg/epp/metrics/testdata/request_error_total_metric b/pkg/epp/metrics/testdata/request_error_total_metric
index 31036eb60..2a2e55364 100644
--- a/pkg/epp/metrics/testdata/request_error_total_metric
+++ b/pkg/epp/metrics/testdata/request_error_total_metric
@@ -1,5 +1,5 @@
-# HELP inference_model_request_error_total [ALPHA] Counter of inference model requests errors broken out for each model and target model.
-# TYPE inference_model_request_error_total counter
-inference_model_request_error_total{error_code="Internal", model_name="m10",target_model_name="t10"} 2
-inference_model_request_error_total{error_code="ModelServerError", model_name="m10",target_model_name="t11"} 1
-inference_model_request_error_total{error_code="InferencePoolResourceExhausted", model_name="m20",target_model_name="t20"} 1
+# HELP inference_objective_request_error_total [ALPHA] Counter of inference objective requests errors broken out for each model and target model.
+# TYPE inference_objective_request_error_total counter
+inference_objective_request_error_total{error_code="Internal", model_name="m10",target_model_name="t10"} 2
+inference_objective_request_error_total{error_code="ModelServerError", model_name="m10",target_model_name="t11"} 1
+inference_objective_request_error_total{error_code="InferencePoolResourceExhausted", model_name="m20",target_model_name="t20"} 1
diff --git a/pkg/epp/metrics/testdata/request_sizes_metric b/pkg/epp/metrics/testdata/request_sizes_metric
index ceca532e2..74e672591 100644
--- a/pkg/epp/metrics/testdata/request_sizes_metric
+++ b/pkg/epp/metrics/testdata/request_sizes_metric
@@ -1,86 +1,86 @@
-# HELP inference_model_request_sizes [ALPHA] Inference model requests size distribution in bytes for each model and target model.
-# TYPE inference_model_request_sizes histogram
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="64"} 0
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="128"} 0
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="256"} 0
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="512"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="1024"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="2048"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="4096"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="8192"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="16384"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="32768"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="65536"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="131072"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="262144"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="524288"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="1.048576e+06"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="2.097152e+06"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="4.194304e+06"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="8.388608e+06"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="1.6777216e+07"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="3.3554432e+07"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="6.7108864e+07"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="1.34217728e+08"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="2.68435456e+08"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="5.36870912e+08"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="1.073741824e+09"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t10",le="+Inf"} 2
-inference_model_request_sizes_sum{model_name="m10",target_model_name="t10"} 1700
-inference_model_request_sizes_count{model_name="m10",target_model_name="t10"} 2
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="64"} 0
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="128"} 0
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="256"} 0
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="512"} 0
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="1024"} 0
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="2048"} 0
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="4096"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="8192"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="16384"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="32768"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="65536"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="131072"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="262144"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="524288"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="1.048576e+06"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="2.097152e+06"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="4.194304e+06"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="8.388608e+06"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="1.6777216e+07"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="3.3554432e+07"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="6.7108864e+07"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="1.34217728e+08"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="2.68435456e+08"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="5.36870912e+08"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="1.073741824e+09"} 1
-inference_model_request_sizes_bucket{model_name="m10",target_model_name="t11",le="+Inf"} 1
-inference_model_request_sizes_sum{model_name="m10",target_model_name="t11"} 2480
-inference_model_request_sizes_count{model_name="m10",target_model_name="t11"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="64"} 0
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="128"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="256"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="512"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="1024"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="2048"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="4096"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="8192"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="16384"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="32768"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="65536"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="131072"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="262144"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="524288"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="1.048576e+06"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="2.097152e+06"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="4.194304e+06"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="8.388608e+06"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="1.6777216e+07"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="3.3554432e+07"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="6.7108864e+07"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="1.34217728e+08"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="2.68435456e+08"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="5.36870912e+08"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="1.073741824e+09"} 1
-inference_model_request_sizes_bucket{model_name="m20",target_model_name="t20",le="+Inf"} 1
-inference_model_request_sizes_sum{model_name="m20",target_model_name="t20"} 80
-inference_model_request_sizes_count{model_name="m20",target_model_name="t20"} 1
+# HELP inference_objective_request_sizes [ALPHA] Inference objective requests size distribution in bytes for each model and target model.
+# TYPE inference_objective_request_sizes histogram
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="64"} 0
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="128"} 0
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="256"} 0
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="512"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="1024"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="2048"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="4096"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="8192"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="16384"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="32768"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="65536"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="131072"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="262144"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="524288"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="1.048576e+06"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="2.097152e+06"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="4.194304e+06"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="8.388608e+06"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="1.6777216e+07"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="3.3554432e+07"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="6.7108864e+07"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="1.34217728e+08"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="2.68435456e+08"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="5.36870912e+08"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="1.073741824e+09"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t10",le="+Inf"} 2
+inference_objective_request_sizes_sum{model_name="m10",target_model_name="t10"} 1700
+inference_objective_request_sizes_count{model_name="m10",target_model_name="t10"} 2
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="64"} 0
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="128"} 0
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="256"} 0
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="512"} 0
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="1024"} 0
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="2048"} 0
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="4096"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="8192"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="16384"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="32768"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="65536"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="131072"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="262144"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="524288"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="1.048576e+06"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="2.097152e+06"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="4.194304e+06"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="8.388608e+06"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="1.6777216e+07"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="3.3554432e+07"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="6.7108864e+07"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="1.34217728e+08"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="2.68435456e+08"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="5.36870912e+08"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="1.073741824e+09"} 1
+inference_objective_request_sizes_bucket{model_name="m10",target_model_name="t11",le="+Inf"} 1
+inference_objective_request_sizes_sum{model_name="m10",target_model_name="t11"} 2480
+inference_objective_request_sizes_count{model_name="m10",target_model_name="t11"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="64"} 0
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="128"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="256"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="512"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="1024"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="2048"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="4096"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="8192"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="16384"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="32768"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="65536"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="131072"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="262144"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="524288"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="1.048576e+06"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="2.097152e+06"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="4.194304e+06"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="8.388608e+06"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="1.6777216e+07"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="3.3554432e+07"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="6.7108864e+07"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="1.34217728e+08"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="2.68435456e+08"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="5.36870912e+08"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="1.073741824e+09"} 1
+inference_objective_request_sizes_bucket{model_name="m20",target_model_name="t20",le="+Inf"} 1
+inference_objective_request_sizes_sum{model_name="m20",target_model_name="t20"} 80
+inference_objective_request_sizes_count{model_name="m20",target_model_name="t20"} 1
diff --git a/pkg/epp/metrics/testdata/request_total_metric b/pkg/epp/metrics/testdata/request_total_metric
index 9c6f48a36..a6200fdc9 100644
--- a/pkg/epp/metrics/testdata/request_total_metric
+++ b/pkg/epp/metrics/testdata/request_total_metric
@@ -1,5 +1,5 @@
-# HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model.
-# TYPE inference_model_request_total counter
-inference_model_request_total{model_name="m10", target_model_name="t10"} 2
-inference_model_request_total{model_name="m10", target_model_name="t11"} 1
-inference_model_request_total{model_name="m20", target_model_name="t20"} 1
+# HELP inference_objective_request_total [ALPHA] Counter of inference objective requests broken out for each model and target model.
+# TYPE inference_objective_request_total counter
+inference_objective_request_total{model_name="m10", target_model_name="t10"} 2
+inference_objective_request_total{model_name="m10", target_model_name="t11"} 1
+inference_objective_request_total{model_name="m20", target_model_name="t20"} 1
diff --git a/pkg/epp/metrics/testdata/response_sizes_metric b/pkg/epp/metrics/testdata/response_sizes_metric
index 7f981090c..a9ad76ecb 100644
--- a/pkg/epp/metrics/testdata/response_sizes_metric
+++ b/pkg/epp/metrics/testdata/response_sizes_metric
@@ -1,56 +1,56 @@
-# HELP inference_model_response_sizes [ALPHA] Inference model responses size distribution in bytes for each model and target model.
-# TYPE inference_model_response_sizes histogram
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t10",le="1"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t10",le="8"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t10",le="16"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t10",le="32"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t10",le="64"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t10",le="128"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t10",le="256"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t10",le="512"} 1
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t10",le="1024"} 1
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t10",le="2048"} 2
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t10",le="4096"} 2
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t10",le="8192"} 2
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t10",le="16384"} 2
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t10",le="32778"} 2
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t10",le="65536"} 2
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t10",le="+Inf"} 2
-inference_model_response_sizes_sum{model_name="m10",target_model_name="t10"} 1700
-inference_model_response_sizes_count{model_name="m10",target_model_name="t10"} 2
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t11",le="1"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t11",le="8"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t11",le="16"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t11",le="32"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t11",le="64"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t11",le="128"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t11",le="256"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t11",le="512"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t11",le="1024"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t11",le="2048"} 0
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t11",le="4096"} 1
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t11",le="8192"} 1
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t11",le="16384"} 1
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t11",le="32778"} 1
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t11",le="65536"} 1
-inference_model_response_sizes_bucket{model_name="m10",target_model_name="t11",le="+Inf"} 1
-inference_model_response_sizes_sum{model_name="m10",target_model_name="t11"} 2480
-inference_model_response_sizes_count{model_name="m10",target_model_name="t11"} 1
-inference_model_response_sizes_bucket{model_name="m20",target_model_name="t20",le="1"} 0
-inference_model_response_sizes_bucket{model_name="m20",target_model_name="t20",le="8"} 0
-inference_model_response_sizes_bucket{model_name="m20",target_model_name="t20",le="16"} 0
-inference_model_response_sizes_bucket{model_name="m20",target_model_name="t20",le="32"} 0
-inference_model_response_sizes_bucket{model_name="m20",target_model_name="t20",le="64"} 0
-inference_model_response_sizes_bucket{model_name="m20",target_model_name="t20",le="128"} 1
-inference_model_response_sizes_bucket{model_name="m20",target_model_name="t20",le="256"} 1
-inference_model_response_sizes_bucket{model_name="m20",target_model_name="t20",le="512"} 1
-inference_model_response_sizes_bucket{model_name="m20",target_model_name="t20",le="1024"} 1
-inference_model_response_sizes_bucket{model_name="m20",target_model_name="t20",le="2048"} 1
-inference_model_response_sizes_bucket{model_name="m20",target_model_name="t20",le="4096"} 1
-inference_model_response_sizes_bucket{model_name="m20",target_model_name="t20",le="8192"} 1
-inference_model_response_sizes_bucket{model_name="m20",target_model_name="t20",le="16384"} 1
-inference_model_response_sizes_bucket{model_name="m20",target_model_name="t20",le="32778"} 1
-inference_model_response_sizes_bucket{model_name="m20",target_model_name="t20",le="65536"} 1
-inference_model_response_sizes_bucket{model_name="m20",target_model_name="t20",le="+Inf"} 1
-inference_model_response_sizes_sum{model_name="m20",target_model_name="t20"} 80
-inference_model_response_sizes_count{model_name="m20",target_model_name="t20"} 1
+# HELP inference_objective_response_sizes [ALPHA] Inference objective responses size distribution in bytes for each model and target model.
+# TYPE inference_objective_response_sizes histogram
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t10",le="1"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t10",le="8"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t10",le="16"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t10",le="32"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t10",le="64"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t10",le="128"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t10",le="256"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t10",le="512"} 1
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t10",le="1024"} 1
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t10",le="2048"} 2
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t10",le="4096"} 2
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t10",le="8192"} 2
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t10",le="16384"} 2
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t10",le="32778"} 2
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t10",le="65536"} 2
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t10",le="+Inf"} 2
+inference_objective_response_sizes_sum{model_name="m10",target_model_name="t10"} 1700
+inference_objective_response_sizes_count{model_name="m10",target_model_name="t10"} 2
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t11",le="1"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t11",le="8"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t11",le="16"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t11",le="32"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t11",le="64"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t11",le="128"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t11",le="256"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t11",le="512"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t11",le="1024"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t11",le="2048"} 0
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t11",le="4096"} 1
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t11",le="8192"} 1
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t11",le="16384"} 1
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t11",le="32778"} 1
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t11",le="65536"} 1
+inference_objective_response_sizes_bucket{model_name="m10",target_model_name="t11",le="+Inf"} 1
+inference_objective_response_sizes_sum{model_name="m10",target_model_name="t11"} 2480
+inference_objective_response_sizes_count{model_name="m10",target_model_name="t11"} 1
+inference_objective_response_sizes_bucket{model_name="m20",target_model_name="t20",le="1"} 0
+inference_objective_response_sizes_bucket{model_name="m20",target_model_name="t20",le="8"} 0
+inference_objective_response_sizes_bucket{model_name="m20",target_model_name="t20",le="16"} 0
+inference_objective_response_sizes_bucket{model_name="m20",target_model_name="t20",le="32"} 0
+inference_objective_response_sizes_bucket{model_name="m20",target_model_name="t20",le="64"} 0
+inference_objective_response_sizes_bucket{model_name="m20",target_model_name="t20",le="128"} 1
+inference_objective_response_sizes_bucket{model_name="m20",target_model_name="t20",le="256"} 1
+inference_objective_response_sizes_bucket{model_name="m20",target_model_name="t20",le="512"} 1
+inference_objective_response_sizes_bucket{model_name="m20",target_model_name="t20",le="1024"} 1
+inference_objective_response_sizes_bucket{model_name="m20",target_model_name="t20",le="2048"} 1
+inference_objective_response_sizes_bucket{model_name="m20",target_model_name="t20",le="4096"} 1
+inference_objective_response_sizes_bucket{model_name="m20",target_model_name="t20",le="8192"} 1
+inference_objective_response_sizes_bucket{model_name="m20",target_model_name="t20",le="16384"} 1
+inference_objective_response_sizes_bucket{model_name="m20",target_model_name="t20",le="32778"} 1
+inference_objective_response_sizes_bucket{model_name="m20",target_model_name="t20",le="65536"} 1
+inference_objective_response_sizes_bucket{model_name="m20",target_model_name="t20",le="+Inf"} 1
+inference_objective_response_sizes_sum{model_name="m20",target_model_name="t20"} 80
+inference_objective_response_sizes_count{model_name="m20",target_model_name="t20"} 1
diff --git a/pkg/epp/metrics/testdata/running_requests_metrics b/pkg/epp/metrics/testdata/running_requests_metrics
index a880e4998..962a50fbf 100644
--- a/pkg/epp/metrics/testdata/running_requests_metrics
+++ b/pkg/epp/metrics/testdata/running_requests_metrics
@@ -1,4 +1,4 @@
-# HELP inference_model_running_requests [ALPHA] Inference model number of running requests in each model.
-# TYPE inference_model_running_requests gauge
-inference_model_running_requests{model_name="m1"} 1
-inference_model_running_requests{model_name="m2"} 1
+# HELP inference_objective_running_requests [ALPHA] Inference objective number of running requests in each model.
+# TYPE inference_objective_running_requests gauge
+inference_objective_running_requests{model_name="m1"} 1
+inference_objective_running_requests{model_name="m2"} 1
diff --git a/pkg/epp/plugins/handle.go b/pkg/epp/plugins/handle.go
index 8c9153cf1..c074e9076 100644
--- a/pkg/epp/plugins/handle.go
+++ b/pkg/epp/plugins/handle.go
@@ -19,6 +19,8 @@ package plugins
import (
"context"
"fmt"
+
+ backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
)
// Handle provides plugins a set of standard data and tools to work with
@@ -27,6 +29,9 @@ type Handle interface {
Context() context.Context
HandlePlugins
+
+ // PodList lists pods matching the given predicate.
+ PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics
}
// HandlePlugins defines a set of APIs to work with instantiated plugins
@@ -44,10 +49,14 @@ type HandlePlugins interface {
GetAllPluginsWithNames() map[string]Plugin
}
+// PodListFunc is a function type that filters and returns a list of pod metrics
+type PodListFunc func(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics
+
// eppHandle is an implementation of the interface plugins.Handle
type eppHandle struct {
ctx context.Context
HandlePlugins
+ podList PodListFunc
}
// Context returns a context the plugins can use, if they need one
@@ -84,12 +93,18 @@ func (h *eppHandlePlugins) GetAllPluginsWithNames() map[string]Plugin {
return h.plugins
}
-func NewEppHandle(ctx context.Context) Handle {
+// PodList lists pods matching the given predicate.
+func (h *eppHandle) PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics {
+ return h.podList(predicate)
+}
+
+func NewEppHandle(ctx context.Context, podList PodListFunc) Handle {
return &eppHandle{
ctx: ctx,
HandlePlugins: &eppHandlePlugins{
plugins: map[string]Plugin{},
},
+ podList: podList,
}
}
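+
+// Illustrative wiring (not part of this change): the PodListFunc passed to NewEppHandle is
+// expected to be backed by the datastore's PodList. Assuming a datastore value ds with a
+// matching PodList signature, a sketch looks like:
+//
+//	handle := NewEppHandle(ctx, ds.PodList)
+//	pods := handle.PodList(func(pm backendmetrics.PodMetrics) bool {
+//		return pm.GetPod() != nil // keep only pods with resolved metadata, for example
+//	})
+//	_ = pods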
diff --git a/pkg/epp/requestcontrol/admission.go b/pkg/epp/requestcontrol/admission.go
new file mode 100644
index 000000000..69fd5adf8
--- /dev/null
+++ b/pkg/epp/requestcontrol/admission.go
@@ -0,0 +1,216 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package requestcontrol
+
+import (
+ "context"
+ "time"
+
+ "sigs.k8s.io/controller-runtime/pkg/log"
+
+ backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
+ errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
+ logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+ requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
+)
+
+// AdmissionController defines the interface for making admission control decisions.
+// Implementations of this interface determine whether an incoming inference request should be accepted or rejected
+// based on various criteria such as system load, fairness, priority, and available capacity.
+type AdmissionController interface {
+ // Admit determines if a request should be admitted.
+ // It is called by the Director for each incoming request.
+ //
+ // Args:
+ // ctx: The request context, carrying deadlines, cancellation signals, and logger.
+ // reqCtx: The handlers.RequestContext containing details about the incoming request.
+ // candidatePods: A list of potential backend pods that can serve the request.
+ // priority: The priority level of the request, as determined by the InferenceObjective.
+ //
+ // Returns:
+ // - nil: If the request is admitted and should proceed to scheduling.
+ // - errutil.Error: If the request is rejected.
+ Admit(
+ ctx context.Context,
+ reqCtx *handlers.RequestContext,
+ candidatePods []backendmetrics.PodMetrics,
+ priority int,
+ ) error
+}
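+
+// For illustration only: a minimal AdmissionController that admits every request could be
+// written as the sketch below. The type name is hypothetical and is not part of this
+// change; it only shows how the interface is intended to be satisfied.
+//
+//	type admitAllController struct{}
+//
+//	func (admitAllController) Admit(
+//		_ context.Context,
+//		_ *handlers.RequestContext,
+//		_ []backendmetrics.PodMetrics,
+//		_ int,
+//	) error {
+//		return nil // admit unconditionally; the request proceeds to scheduling
+//	}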
+
+// saturationDetector defines the minimal interface required for checking if the backend pool is saturated.
+type saturationDetector interface {
+ IsSaturated(ctx context.Context, candidatePods []backendmetrics.PodMetrics) bool
+}
+
+// flowController defines the minimal interface required by FlowControlAdmissionController for enqueuing requests and
+// waiting for an admission outcome.
+type flowController interface {
+ EnqueueAndWait(ctx context.Context, req types.FlowControlRequest) (types.QueueOutcome, error)
+}
+
+// rejectIfSheddableAndSaturated checks if a request should be immediately rejected because it's sheddable
+// (priority < 0) and the system is saturated.
+func rejectIfSheddableAndSaturated(
+ ctx context.Context,
+ sd saturationDetector,
+ reqCtx *handlers.RequestContext,
+ candidatePods []backendmetrics.PodMetrics,
+ priority int,
+) error {
+ if requtil.IsSheddable(priority) {
+ logger := log.FromContext(ctx)
+ if sd.IsSaturated(ctx, candidatePods) {
+ logger.V(logutil.TRACE).Info("Request rejected: system saturated and request is sheddable",
+ "requestID", reqCtx.SchedulingRequest.RequestId)
+ return errutil.Error{
+ Code: errutil.InferencePoolResourceExhausted,
+ Msg: "system saturated, sheddable request dropped",
+ }
+ }
+ }
+ return nil
+}
+
+// --- LegacyAdmissionController ---
+
+// LegacyAdmissionController implements saturation-based admission control.
+// It rejects sheddable requests (priority < 0) if the saturationDetector indicates that the system is currently
+// saturated. Non-sheddable requests always bypass the saturation check.
+type LegacyAdmissionController struct {
+ saturationDetector saturationDetector
+}
+
+// NewLegacyAdmissionController creates a new LegacyAdmissionController.
+func NewLegacyAdmissionController(sd saturationDetector) *LegacyAdmissionController {
+ return &LegacyAdmissionController{saturationDetector: sd}
+}
+
+// Admit implements the AdmissionController interface for the legacy strategy.
+// It checks for saturation only for requests with priority < 0.
+func (lac *LegacyAdmissionController) Admit(
+ ctx context.Context,
+ reqCtx *handlers.RequestContext,
+ candidatePods []backendmetrics.PodMetrics,
+ priority int,
+) error {
+ logger := log.FromContext(ctx)
+ logger.V(logutil.TRACE).Info("Executing LegacyAdmissionController",
+ "priority", priority, "fairnessID", reqCtx.FairnessID)
+ if err := rejectIfSheddableAndSaturated(ctx, lac.saturationDetector, reqCtx, candidatePods, priority); err != nil {
+ return err
+ }
+ logger.V(logutil.TRACE).Info("Request admitted", "requestID", reqCtx.SchedulingRequest.RequestId)
+ return nil
+}
+
+// --- FlowControlAdmissionController ---
+
+// FlowControlAdmissionController delegates admission decisions to the Flow Control layer.
+// It first checks if the request is sheddable and the system is saturated, rejecting immediately if both conditions are
+// true. Otherwise, it uses the provided flowController to enqueue the request and await an outcome.
+type FlowControlAdmissionController struct {
+ saturationDetector saturationDetector
+ flowController flowController
+}
+
+// NewFlowControlAdmissionController creates a new FlowControlAdmissionController.
+// It requires a saturationDetector and a flowController implementation.
+func NewFlowControlAdmissionController(sd saturationDetector, fc flowController) *FlowControlAdmissionController {
+ return &FlowControlAdmissionController{
+ saturationDetector: sd,
+ flowController: fc,
+ }
+}
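+
+// A hedged usage sketch (variable names are illustrative, not part of this change): given a
+// saturation detector sd and a flow controller fc satisfying the interfaces above, the
+// controller is constructed once and then consulted per request:
+//
+//	ac := NewFlowControlAdmissionController(sd, fc)
+//	if err := ac.Admit(ctx, reqCtx, candidatePods, priority); err != nil {
+//		return err // rejected; err carries an errutil.Error code such as InferencePoolResourceExhausted
+//	}
+//	// admitted: continue to scheduling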
+
+// Admit implements the AdmissionController interface by checking for saturation on sheddable requests first, then
+// deferring to the Flow Control system.
+func (fcac *FlowControlAdmissionController) Admit(
+ ctx context.Context,
+ reqCtx *handlers.RequestContext,
+ candidatePods []backendmetrics.PodMetrics,
+ priority int,
+) error {
+ logger := log.FromContext(ctx)
+ logger.V(logutil.TRACE).Info("Executing FlowControlAdmissionController",
+ "requestID", reqCtx.SchedulingRequest.RequestId, "priority", priority, "fairnessID", reqCtx.FairnessID)
+ if err := rejectIfSheddableAndSaturated(ctx, fcac.saturationDetector, reqCtx, candidatePods, priority); err != nil {
+ return err
+ }
+
+ logger.V(logutil.TRACE).Info("Request proceeding to flow control", "requestID", reqCtx.SchedulingRequest.RequestId)
+
+ fcReq := &flowControlRequest{
+ requestID: reqCtx.SchedulingRequest.RequestId,
+ fairnessID: reqCtx.FairnessID,
+ priority: priority,
+ requestByteSize: uint64(reqCtx.RequestSize),
+ candidatePods: candidatePods,
+ }
+
+ outcome, err := fcac.flowController.EnqueueAndWait(ctx, fcReq)
+ logger.V(logutil.DEBUG).Info("Flow control outcome",
+ "requestID", reqCtx.SchedulingRequest.RequestId, "outcome", outcome, "error", err)
+ return translateFlowControlOutcome(outcome, err)
+}
+
+// flowControlRequest is an adapter that implements the types.FlowControlRequest interface.
+type flowControlRequest struct {
+ requestID string
+ fairnessID string
+ priority int
+ requestByteSize uint64
+ candidatePods []backendmetrics.PodMetrics
+}
+
+var _ types.FlowControlRequest = &flowControlRequest{}
+
+func (r *flowControlRequest) ID() string { return r.requestID }
+func (r *flowControlRequest) InitialEffectiveTTL() time.Duration { return 0 } // Use controller default.
+func (r *flowControlRequest) ByteSize() uint64 { return r.requestByteSize }
+func (r *flowControlRequest) CandidatePodsForScheduling() []backendmetrics.PodMetrics {
+ return r.candidatePods
+}
+func (r *flowControlRequest) FlowKey() types.FlowKey {
+ return types.FlowKey{ID: r.fairnessID, Priority: r.priority}
+}
+
+// translateFlowControlOutcome maps the context-rich outcome of the Flow Control layer to the public errutil.Error
+// contract used by the Director.
+func translateFlowControlOutcome(outcome types.QueueOutcome, err error) error {
+ msg := "request rejected by flow control"
+ if err != nil {
+ msg = err.Error()
+ }
+
+ switch outcome {
+ case types.QueueOutcomeDispatched:
+ return nil
+ case types.QueueOutcomeRejectedCapacity:
+ return errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: msg}
+ case types.QueueOutcomeEvictedTTL:
+ return errutil.Error{Code: errutil.ServiceUnavailable, Msg: "request timed out in queue: " + msg}
+ case types.QueueOutcomeEvictedContextCancelled:
+ return errutil.Error{Code: errutil.ServiceUnavailable, Msg: "client disconnected: " + msg}
+ case types.QueueOutcomeRejectedOther, types.QueueOutcomeEvictedOther:
+ return errutil.Error{Code: errutil.Internal, Msg: "internal flow control error: " + msg}
+ default:
+ return errutil.Error{Code: errutil.Internal, Msg: "unhandled flow control outcome: " + msg}
+ }
+}
diff --git a/pkg/epp/requestcontrol/admission_test.go b/pkg/epp/requestcontrol/admission_test.go
new file mode 100644
index 000000000..085778200
--- /dev/null
+++ b/pkg/epp/requestcontrol/admission_test.go
@@ -0,0 +1,282 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package requestcontrol
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+ fctypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
+ schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
+ logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+)
+
+// --- Mocks ---
+
+type mockSaturationDetector struct {
+ isSaturated bool
+}
+
+func (m *mockSaturationDetector) IsSaturated(_ context.Context, _ []backendmetrics.PodMetrics) bool {
+ return m.isSaturated
+}
+
+type mockFlowController struct {
+ outcome fctypes.QueueOutcome
+ err error
+ called bool
+}
+
+func (m *mockFlowController) EnqueueAndWait(
+ _ context.Context,
+ _ fctypes.FlowControlRequest,
+) (fctypes.QueueOutcome, error) {
+ m.called = true
+ return m.outcome, m.err
+}
+
+func TestLegacyAdmissionController_Admit(t *testing.T) {
+ t.Parallel()
+ ctx := logutil.NewTestLoggerIntoContext(context.Background())
+ candidatePods := []backendmetrics.PodMetrics{}
+ reqCtx := &handlers.RequestContext{
+ SchedulingRequest: &schedulingtypes.LLMRequest{RequestId: "test-req"},
+ }
+
+ testCases := []struct {
+ name string
+ priority int
+ isSaturated bool
+ expectErr bool
+ expectErrCode string
+ expectErrSubstr string
+ }{
+ {
+ name: "non_sheddable_saturated_admit",
+ priority: 0,
+ isSaturated: true,
+ expectErr: false,
+ },
+ {
+ name: "sheddable_not_saturated_admit",
+ priority: -1,
+ isSaturated: false,
+ expectErr: false,
+ },
+ {
+ name: "sheddable_saturated_reject",
+ priority: -1,
+ isSaturated: true,
+ expectErr: true,
+ expectErrCode: errutil.InferencePoolResourceExhausted,
+ expectErrSubstr: "system saturated, sheddable request dropped",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ saturationDetector := &mockSaturationDetector{isSaturated: tc.isSaturated}
+ ac := NewLegacyAdmissionController(saturationDetector)
+
+ err := ac.Admit(ctx, reqCtx, candidatePods, tc.priority)
+
+ if !tc.expectErr {
+ assert.NoError(t, err, "Admit() should not have returned an error for scenario: %s", tc.name)
+ } else {
+ require.Error(t, err, "Admit() should have returned an error for scenario: %s", tc.name)
+ var e errutil.Error
+ if assert.ErrorAs(t, err, &e, "error should be of type errutil.Error") {
+ assert.Equal(t, tc.expectErrCode, e.Code, "incorrect error code for scenario: %s", tc.name)
+ assert.Contains(t, e.Msg, tc.expectErrSubstr, "incorrect error message substring for scenario: %s", tc.name)
+ }
+ }
+ })
+ }
+}
+
+func TestFlowControlRequestAdapter(t *testing.T) {
+ t.Parallel()
+ candidatePods := []backendmetrics.PodMetrics{&backendmetrics.FakePodMetrics{}}
+
+ testCases := []struct {
+ name string
+ requestID string
+ fairnessID string
+ priority int
+ requestByteSize uint64
+ expectFlowKey fctypes.FlowKey
+ }{
+ {
+ name: "simple",
+ requestID: "req-1",
+ fairnessID: "flow-1",
+ priority: 10,
+ requestByteSize: 1024,
+ expectFlowKey: fctypes.FlowKey{ID: "flow-1", Priority: 10},
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ fcReq := &flowControlRequest{
+ requestID: tc.requestID,
+ fairnessID: tc.fairnessID,
+ priority: tc.priority,
+ requestByteSize: tc.requestByteSize,
+ candidatePods: candidatePods,
+ }
+
+ assert.Equal(t, tc.requestID, fcReq.ID(), "ID() mismatch")
+ assert.Equal(t, tc.requestByteSize, fcReq.ByteSize(), "ByteSize() mismatch")
+ assert.Equal(t, candidatePods, fcReq.CandidatePodsForScheduling(), "CandidatePodsForScheduling() mismatch")
+ assert.Equal(t, tc.expectFlowKey, fcReq.FlowKey(), "FlowKey() mismatch")
+ assert.Zero(t, fcReq.InitialEffectiveTTL(), "InitialEffectiveTTL() should be zero")
+ })
+ }
+}
+
+func TestFlowControlAdmissionController_Admit(t *testing.T) {
+ t.Parallel()
+ ctx := logutil.NewTestLoggerIntoContext(context.Background())
+ candidatePods := []backendmetrics.PodMetrics{}
+
+ reqCtx := &handlers.RequestContext{
+ SchedulingRequest: &schedulingtypes.LLMRequest{RequestId: "test-req"},
+ }
+
+ testCases := []struct {
+ name string
+ priority int
+ isSaturated bool
+ fcOutcome fctypes.QueueOutcome
+ fcErr error
+ expectErr bool
+ expectErrCode string
+ expectErrSubstr string
+ expectFCSkipped bool
+ }{
+ {
+ name: "sheddable_saturated_reject",
+ priority: -1,
+ isSaturated: true,
+ expectErr: true,
+ expectErrCode: errutil.InferencePoolResourceExhausted,
+ expectErrSubstr: "system saturated, sheddable request dropped",
+ expectFCSkipped: true,
+ },
+ {
+ name: "sheddable_not_saturated_dispatch",
+ priority: -1,
+ isSaturated: false,
+ fcOutcome: fctypes.QueueOutcomeDispatched,
+ expectErr: false,
+ },
+ {
+ name: "non_sheddable_saturated_dispatch",
+ priority: 0,
+ isSaturated: true,
+ fcOutcome: fctypes.QueueOutcomeDispatched,
+ expectErr: false,
+ },
+ {
+ name: "fc_reject_capacity",
+ priority: 0,
+ fcOutcome: fctypes.QueueOutcomeRejectedCapacity,
+ expectErr: true,
+ expectErrCode: errutil.InferencePoolResourceExhausted,
+ expectErrSubstr: "request rejected by flow control",
+ },
+ {
+ name: "fc_evict_ttl",
+ priority: 0,
+ fcOutcome: fctypes.QueueOutcomeEvictedTTL,
+ fcErr: errors.New("timeout"),
+ expectErr: true,
+ expectErrCode: errutil.ServiceUnavailable,
+ expectErrSubstr: "request timed out in queue: timeout",
+ },
+ {
+ name: "fc_evict_context_cancelled",
+ priority: 0,
+ fcOutcome: fctypes.QueueOutcomeEvictedContextCancelled,
+ expectErr: true,
+ expectErrCode: errutil.ServiceUnavailable,
+ expectErrSubstr: "client disconnected",
+ },
+ {
+ name: "fc_reject_other",
+ priority: 0,
+ fcOutcome: fctypes.QueueOutcomeRejectedOther,
+ expectErr: true,
+ expectErrCode: errutil.Internal,
+ expectErrSubstr: "internal flow control error",
+ },
+ {
+ name: "fc_evict_other",
+ priority: 0,
+ fcOutcome: fctypes.QueueOutcomeEvictedOther,
+ fcErr: errors.New("internal error"),
+ expectErr: true,
+ expectErrCode: errutil.Internal,
+ expectErrSubstr: "internal flow control error: internal error",
+ },
+ {
+ name: "fc_unhandled_outcome",
+ priority: 0,
+ fcOutcome: fctypes.QueueOutcomeNotYetFinalized,
+ expectErr: true,
+ expectErrCode: errutil.Internal,
+ expectErrSubstr: "unhandled flow control outcome",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ sd := &mockSaturationDetector{isSaturated: tc.isSaturated}
+ fc := &mockFlowController{outcome: tc.fcOutcome, err: tc.fcErr}
+ ac := NewFlowControlAdmissionController(sd, fc)
+
+ err := ac.Admit(ctx, reqCtx, candidatePods, tc.priority)
+
+ if tc.expectFCSkipped {
+ assert.False(t, fc.called, "FlowController should not have been called for scenario: %s", tc.name)
+ } else {
+ assert.True(t, fc.called, "FlowController should have been called for scenario: %s", tc.name)
+ }
+
+ if !tc.expectErr {
+ assert.NoError(t, err, "Admit() returned an unexpected error for scenario: %s", tc.name)
+ } else {
+ require.Error(t, err, "Admit() should have returned an error for scenario: %s", tc.name)
+ var e errutil.Error
+ if assert.ErrorAs(t, err, &e, "error should be of type errutil.Error") {
+ assert.Equal(t, tc.expectErrCode, e.Code, "incorrect error code for scenario: %s", tc.name)
+ assert.Contains(t, e.Msg, tc.expectErrSubstr, "incorrect error message substring for scenario: %s", tc.name)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go
index 662491bdb..feda7edca 100644
--- a/pkg/epp/requestcontrol/director.go
+++ b/pkg/epp/requestcontrol/director.go
@@ -33,7 +33,6 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metadata"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
@@ -50,55 +49,6 @@ type Datastore interface {
PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics
}
-/*
-NOTE: To support this refined logic, the `handlers.RequestContext` struct
-(defined in a different package) would need to be updated as follows:
-
-type RequestContext struct {
- // ... existing fields ...
- RequestReceivedTimestamp time.Time
- FirstTokenTimestamp time.Time
- ResponseCompleteTimestamp time.Time
- IsModelServerStreaming func() bool
- ResponseComplete bool
- Prompt string
- LastSeenMetrics *backend.Metrics
- // ... etc ...
-
- // -- New fields for latency predictor --
- PredictedTTFT float64 // The predicted TTFT in milliseconds
- PredictedTPOT float64 // The predicted TPOT in milliseconds
- TTFT float64 // Actual Time To First Token in milliseconds
- LastTokenTimestamp time.Time // Timestamp of the last token received
- TPOTObservations []float64 // All actual inter-token latencies (for which we have predictions)
- PredictedTPOTObservations []float64 // Predicted inter-token latencies (only for sampled tokens)
- GeneratedTokenCount int // Current number of tokens generated
-}
-
-*/
-
-const (
- subsetHintNamespace = "envoy.lb.subset_hint"
- subsetHintKey = "x-gateway-destination-endpoint-subset"
-)
-
-const (
- // Poisson sampling parameters for predictions
- defaultSamplingMean = 100 // Mean interval between prediction samples (tokens)
- maxSampledTokens = 20 // Maximum number of prediction samples per request
-)
-
-// calculateRunningAverage calculates the running average efficiently
-func calculateRunningAverage(currentAvg float64, newValue float64, count int) float64 {
- if count == 0 {
- return 0
- }
- if count == 1 {
- return newValue
- }
- return currentAvg + (newValue-currentAvg)/float64(count)
-}
-
// parseFloatHeader retrieves a header by name, parses it as a float64,
// and returns the value or an error if the header is missing or invalid.
func parseFloatHeader(reqCtx *handlers.RequestContext, headerName string) (float64, bool, error) {
@@ -148,46 +98,43 @@ type Scheduler interface {
Schedule(ctx context.Context, request *schedulingtypes.LLMRequest, candidatePods []schedulingtypes.Pod) (result *schedulingtypes.SchedulingResult, err error)
}
-// SaturationDetector provides a signal indicating whether the backends are considered saturated.
-type SaturationDetector interface {
- IsSaturated(ctx context.Context, candidatePods []backendmetrics.PodMetrics) bool
-}
-
// NewDirectorWithConfig creates a new Director instance with all dependencies.
-func NewDirectorWithConfig(datastore datastore.Datastore, scheduler Scheduler, saturationDetector SaturationDetector, config *Config) *Director {
+func NewDirectorWithConfig(
+ datastore Datastore,
+ scheduler Scheduler,
+ admissionController AdmissionController,
+ config *Config,
+) *Director {
return &Director{
- datastore: datastore,
- scheduler: scheduler,
- saturationDetector: saturationDetector,
- preRequestPlugins: config.preRequestPlugins,
- postResponsePlugins: config.postResponsePlugins,
- postResponseChunkPlugins: config.postResponseChunkPlugins,
- postResponseCompletePlugins: config.postResponseCompletePlugins,
- defaultPriority: 0, // define default priority explicitly
+ datastore: datastore,
+ scheduler: scheduler,
+ admissionController: admissionController,
+ requestControlPlugins: *config,
+ defaultPriority: 0, // define default priority explicitly
}
}
-// Director orchestrates the request handling flow, including scheduling.
+// Director orchestrates the request handling flow after initial parsing by the handler.
+// Its responsibilities include:
+// - Retrieving request metadata and relevant objectives.
+// - Determining candidate pods.
+// - Performing admission control via the AdmissionController.
+// - Scheduling the request to target pod(s) via the Scheduler.
+// - Running PreRequest plugins.
+// - Preparing the request context for the Envoy ext_proc filter to route the request.
+// - Running PostResponse plugins.
type Director struct {
- datastore datastore.Datastore
- scheduler Scheduler
- saturationDetector SaturationDetector
- preRequestPlugins []PreRequest
- postResponsePlugins []PostResponse
- postResponseChunkPlugins []PostResponseChunk
- postResponseCompletePlugins []PostResponseComplete
+ datastore Datastore
+ scheduler Scheduler
+ admissionController AdmissionController
+ requestControlPlugins Config
// we just need a pointer to an int variable since priority is a pointer in InferenceObjective
// no need to set this in the constructor, since the value we want is the default int val
// and value types cannot be nil
defaultPriority int
}
-// HandleRequest orchestrates the request lifecycle:
-// 1. Parses request details.
-// 2. Calls admitRequest for admission control.
-// 3. Calls Scheduler.Schedule if request is approved.
-// 4. Calls prepareRequest to populate RequestContext with result and call PreRequest plugins.
-//
+// HandleRequest orchestrates the request lifecycle.
// It always returns the requestContext even in the error case, as the request context is used in error handling.
func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
logger := log.FromContext(ctx)
@@ -206,20 +153,17 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
}
reqCtx.Request.Body["model"] = reqCtx.TargetModelName
- prompt, err := requtil.ExtractPromptFromRequestBody(requestBodyMap)
+ requestBody, err := requtil.ExtractRequestBody(reqCtx.Request.Body)
if err != nil {
- return reqCtx, err
+ return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Errorf("failed to extract request data: %w", err).Error()}
}
+
infObjective := d.datastore.ObjectiveGet(reqCtx.ObjectiveKey)
if infObjective == nil {
logger.V(logutil.VERBOSE).Info("No associated InferenceObjective found, using default", "objectiveKey", reqCtx.ObjectiveKey)
- priority := d.defaultPriority
- if strings.Contains(reqCtx.ObjectiveKey, "sheddable") {
- priority = -1
- }
infObjective = &v1alpha2.InferenceObjective{
Spec: v1alpha2.InferenceObjectiveSpec{
- Priority: &priority,
+ Priority: &d.defaultPriority,
},
}
} else if infObjective.Spec.Priority == nil {
@@ -247,7 +191,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
reqCtx.SchedulingRequest = &schedulingtypes.LLMRequest{
RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
TargetModel: reqCtx.TargetModelName,
- Prompt: prompt,
+ Body: requestBody,
Headers: reqCtx.Request.Headers,
TTFTSLO: ttftSLO,
AvgTPOTSLO: avgTPOTSLO,
@@ -266,25 +210,17 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
return reqCtx, errutil.Error{Code: errutil.ServiceUnavailable, Msg: "failed to find candidate pods for serving the request"}
}
- // TODO
- // 1. Create datastore request object
- // 2. Read/Write and maybe Drop to it during Schedule() and admitRequest()
- // 3. Add it to the scheduled pod's RequestPriorityQueue
- // 4. Drop from pod's RequestPriorityQueue and datastore global map when request is fully processed
-
- //
+ if err := d.admissionController.Admit(ctx, reqCtx, candidatePods, *infObjective.Spec.Priority); err != nil {
+ logger.V(logutil.DEFAULT).Info("Request rejected by admission control", "error", err)
+ return reqCtx, err
+ }
result, err := d.scheduler.Schedule(ctx, reqCtx.SchedulingRequest, d.toSchedulerPodMetrics(candidatePods))
if err != nil {
return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()}
}
- // Admission Control check
- if err := d.admitRequest(ctx, candidatePods, reqCtx.SchedulingRequest, *infObjective.Spec.Priority, reqCtx.FairnessID); err != nil {
- return reqCtx, err
- }
-
- // --- 4. Prepare Request (Populates RequestContext and call PreRequest plugins) ---
+ // Prepare Request (populates RequestContext and calls PreRequest plugins)
// Insert target endpoint to instruct Envoy to route requests to the specified target pod and attach the port number.
// Invoke PreRequest registered plugins.
reqCtx, err = d.prepareRequest(ctx, reqCtx, result)
@@ -295,33 +231,6 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
return reqCtx, nil
}
-// admitRequest handles admission control to decide whether or not to accept the request
-// based on the request priority and system saturation state.
-func (d *Director) admitRequest(ctx context.Context, candidatePods []backendmetrics.PodMetrics, request *schedulingtypes.LLMRequest, requestPriority int, fairnessID string) error {
- logger := log.FromContext(ctx)
-
- logger.V(logutil.DEBUG).Info("Entering Flow Control", "priority", requestPriority, "fairnessID", fairnessID)
-
- // This will be removed in favor of a more robust implementation (Flow Control) in the very near future.
- // TODO: Make this a configurable value.
- // Tracking issue https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/1347
- if requestPriority >= 0 {
- logger.V(logutil.DEBUG).Info("Non-sheddable request bypassing saturation check.")
- return nil
- } else {
- logger.V(logutil.DEBUG).Info("Sheddable request subject to saturation check.")
- }
-
- if d.saturationDetector.IsSaturated(ctx, candidatePods) || !request.HasValidPod { // Assuming non-nil Saturation Detector
- return errutil.Error{
- Code: errutil.InferencePoolResourceExhausted,
- Msg: "system saturated, sheddable request dropped",
- }
- }
-
- return nil
-}
-
// getCandidatePodsForScheduling gets the list of relevant endpoints for the scheduling cycle from the datastore.
// according to EPP protocol, if "x-gateway-destination-endpoint-subset" is set on the request metadata and specifies
// a subset of endpoints, only these endpoints will be considered as candidates for the scheduler.
@@ -357,7 +266,7 @@ func (d *Director) getCandidatePodsForScheduling(ctx context.Context, requestMet
podTotalCount := 0
podFilteredList := d.datastore.PodList(func(pm backendmetrics.PodMetrics) bool {
podTotalCount++
- if _, found := endpoints[pm.GetPod().Address]; found {
+ if _, found := endpoints[pm.GetPod().GetIPAddress()]; found {
return true
}
return false
@@ -376,20 +285,12 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC
return reqCtx, errutil.Error{Code: errutil.Internal, Msg: "results must be greater than zero"}
}
// primary profile is used to set destination
- pool, err := d.datastore.PoolGet()
- if err != nil {
- return reqCtx, err
- }
targetPods := []*backend.Pod{}
- if len(pool.Spec.TargetPorts) != 1 {
- return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "targetPorts should have length 1"}
- }
- targetPort := int(pool.Spec.TargetPorts[0].Number)
targetEndpoints := []string{}
for _, pod := range result.ProfileResults[result.PrimaryProfileName].TargetPods {
curPod := pod.GetPod()
- curEndpoint := net.JoinHostPort(curPod.Address, strconv.Itoa(targetPort))
+ curEndpoint := net.JoinHostPort(curPod.GetIPAddress(), curPod.GetPort())
targetPods = append(targetPods, curPod)
targetEndpoints = append(targetEndpoints, curEndpoint)
}
@@ -400,10 +301,7 @@ func (d *Director) prepareRequest(ctx context.Context, reqCtx *handlers.RequestC
reqCtx.TargetPod = targetPods[0]
reqCtx.TargetEndpoint = multiEndpointString
- d.runPreRequestPlugins(ctx, reqCtx.SchedulingRequest, result, targetPort)
- reqCtx.SchedulingResult = result
- reqCtx.LastSeenMetrics = make(map[string]*backendmetrics.MetricsState)
- RefreshLastSeenMetrics(ctx, reqCtx)
+ d.runPreRequestPlugins(ctx, reqCtx.SchedulingRequest, result)
return reqCtx, nil
}
@@ -417,36 +315,47 @@ func (d *Director) toSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []sch
return pm
}
-// HandleResponseHeaders is called when the first chunk of the response arrives.
-func (d *Director) HandleResponse(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
- logger := log.FromContext(ctx).WithValues("stage", "headers")
- logger.V(logutil.DEBUG).Info("Entering HandleResponseHeaders")
+// HandleResponseReceived is called when the response headers are received.
+func (d *Director) HandleResponseReceived(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
+ response := &Response{
+ RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
+ Headers: reqCtx.Response.Headers,
+ }
- d.runPostResponsePlugins(ctx, reqCtx)
+ // TODO: to extend fallback functionality, handle cases where target pod is unavailable
+ // https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/1224
+ d.runResponseReceivedPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
- logger.V(logutil.DEBUG).Info("Exiting HandleResponseHeaders")
return reqCtx, nil
}
-func (d *Director) HandleResponseBodyChunk(ctx context.Context, reqCtx *handlers.RequestContext) error {
+// HandleResponseBodyStreaming is called every time a chunk of the response body is received.
+func (d *Director) HandleResponseBodyStreaming(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
logger.V(logutil.TRACE).Info("Entering HandleResponseBodyChunk")
+ response := &Response{
+ RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
+ Headers: reqCtx.Response.Headers,
+ }
- d.runPostResponseChunkPlugins(ctx, reqCtx)
+ d.runResponseStreamingPlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
logger.V(logutil.TRACE).Info("Exiting HandleResponseBodyChunk")
- return nil
+ return reqCtx, nil
}
// HandleResponseBodyComplete is called when the response body is fully received.
-// It runs the PostResponseComplete plugins.
-func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) error {
+func (d *Director) HandleResponseBodyComplete(ctx context.Context, reqCtx *handlers.RequestContext) (*handlers.RequestContext, error) {
logger := log.FromContext(ctx).WithValues("stage", "bodyChunk")
logger.V(logutil.DEBUG).Info("Entering HandleResponseBodyComplete")
+ response := &Response{
+ RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
+ Headers: reqCtx.Response.Headers,
+ }
- d.runPostResponseCompletePlugins(ctx, reqCtx)
+ d.runResponseCompletePlugins(ctx, reqCtx.SchedulingRequest, response, reqCtx.TargetPod)
logger.V(logutil.DEBUG).Info("Exiting HandleResponseBodyComplete")
- return nil
+ return reqCtx, nil
}
func (d *Director) GetRandomPod() *backend.Pod {
@@ -460,45 +369,46 @@ func (d *Director) GetRandomPod() *backend.Pod {
}
func (d *Director) runPreRequestPlugins(ctx context.Context, request *schedulingtypes.LLMRequest,
- schedulingResult *schedulingtypes.SchedulingResult, targetPort int) {
+ schedulingResult *schedulingtypes.SchedulingResult) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
- for _, plugin := range d.preRequestPlugins {
- loggerDebug.Info("Running pre-request plugin", "plugin", plugin.TypedName())
+ for _, plugin := range d.requestControlPlugins.preRequestPlugins {
+ loggerDebug.Info("Running PreRequest plugin", "plugin", plugin.TypedName())
before := time.Now()
- plugin.PreRequest(ctx, request, schedulingResult, targetPort)
+ plugin.PreRequest(ctx, request, schedulingResult)
metrics.RecordPluginProcessingLatency(PreRequestExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
- loggerDebug.Info("Completed running pre-request plugin successfully", "plugin", plugin.TypedName())
+ loggerDebug.Info("Completed running PreRequest plugin successfully", "plugin", plugin.TypedName())
}
}
-func (d *Director) runPostResponsePlugins(ctx context.Context, reqCtx *handlers.RequestContext) {
+func (d *Director) runResponseReceivedPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
- for _, plugin := range d.postResponsePlugins {
- loggerDebug.Info("Running post-response plugin", "plugin", plugin.TypedName())
+ for _, plugin := range d.requestControlPlugins.responseReceivedPlugins {
+ loggerDebug.Info("Running ResponseReceived plugin", "plugin", plugin.TypedName())
before := time.Now()
- plugin.PostResponse(ctx, reqCtx)
- metrics.RecordPluginProcessingLatency(PostResponseExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
- loggerDebug.Info("Completed running post-response plugin successfully", "plugin", plugin.TypedName())
+ plugin.ResponseReceived(ctx, request, response, targetPod)
+ metrics.RecordPluginProcessingLatency(ResponseReceivedExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
+ loggerDebug.Info("Completed running ResponseReceived plugin successfully", "plugin", plugin.TypedName())
}
}
-func (d *Director) runPostResponseChunkPlugins(ctx context.Context, reqCtx *handlers.RequestContext) {
+func (d *Director) runResponseStreamingPlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
loggerTrace := log.FromContext(ctx).V(logutil.TRACE)
- for _, plugin := range d.postResponseChunkPlugins {
- loggerTrace.Info("Running post-response chunk plugin", "plugin", plugin.TypedName().Type)
+ for _, plugin := range d.requestControlPlugins.responseStreamingPlugins {
+ loggerTrace.Info("Running ResponseStreaming plugin", "plugin", plugin.TypedName())
before := time.Now()
- plugin.PostResponseChunk(ctx, reqCtx)
- metrics.RecordPluginProcessingLatency(PostResponseChunkExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
+ plugin.ResponseStreaming(ctx, request, response, targetPod)
+ metrics.RecordPluginProcessingLatency(ResponseStreamingExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
+ loggerTrace.Info("Completed running ResponseStreaming plugin successfully", "plugin", plugin.TypedName())
}
}
-func (d *Director) runPostResponseCompletePlugins(ctx context.Context, reqCtx *handlers.RequestContext) {
+func (d *Director) runResponseCompletePlugins(ctx context.Context, request *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
- for _, plugin := range d.postResponseCompletePlugins {
- loggerDebug.Info("Running post-response complete plugin", "plugin", plugin.TypedName().Type)
+ for _, plugin := range d.requestControlPlugins.responseCompletePlugins {
+ loggerDebug.Info("Running ResponseComplete plugin", "plugin", plugin.TypedName())
before := time.Now()
- plugin.PostResponseComplete(ctx, reqCtx)
- metrics.RecordPluginProcessingLatency(PostResponseCompleteExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
- loggerDebug.Info("Completed running post-response complete plugin successfully", "plugin", plugin.TypedName())
+ plugin.ResponseComplete(ctx, request, response, targetPod)
+ metrics.RecordPluginProcessingLatency(ResponseCompleteExtensionPoint, plugin.TypedName().Type, plugin.TypedName().Name, time.Since(before))
+ loggerDebug.Info("Completed running ResponseComplete plugin successfully", "plugin", plugin.TypedName())
}
}
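
The deleted admitRequest helper can be re-expressed as a pluggable AdmissionController. The sketch below is illustrative only (the type name and the func-valued saturation signal are assumptions, not part of this change); it mirrors the removed priority/saturation logic against the Admit signature the Director now calls.

// prioritySheddingAdmissionController is a hypothetical sketch of an AdmissionController.
type prioritySheddingAdmissionController struct {
	// isSaturated stands in for whatever saturation signal the deployment wires in.
	isSaturated func(ctx context.Context, pods []backendmetrics.PodMetrics) bool
}

func (a *prioritySheddingAdmissionController) Admit(ctx context.Context, _ *handlers.RequestContext,
	candidatePods []backendmetrics.PodMetrics, priority int) error {
	if priority >= 0 {
		return nil // non-sheddable requests bypass the saturation check
	}
	if a.isSaturated(ctx, candidatePods) {
		return errutil.Error{
			Code: errutil.InferencePoolResourceExhausted,
			Msg:  "system saturated, sheddable request dropped",
		}
	}
	return nil
}
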
diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go
index 61a8b31be..c3111ed5e 100644
--- a/pkg/epp/requestcontrol/director_test.go
+++ b/pkg/epp/requestcontrol/director_test.go
@@ -40,7 +40,6 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/apix/v1alpha2"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/latencypredictorasync"
@@ -55,12 +54,17 @@ import (
// --- Mocks ---
-type mockSaturationDetector struct {
- isSaturated bool
+type mockAdmissionController struct {
+ admitErr error
}
-func (m *mockSaturationDetector) IsSaturated(_ context.Context, _ []backendmetrics.PodMetrics) bool {
- return m.isSaturated
+func (m *mockAdmissionController) Admit(
+ _ context.Context,
+ _ *handlers.RequestContext,
+ _ []backendmetrics.PodMetrics,
+ _ int,
+) error {
+ return m.admitErr
}
// Updated mock scheduler to handle the new Schedule method signature
@@ -143,27 +147,7 @@ func (ds *mockDatastore) PodList(predicate func(backendmetrics.PodMetrics) bool)
return res
}
-func (ds *mockDatastore) PodDelete(namespacedName types.NamespacedName) {}
-func (ds *mockDatastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool { return true }
-func (ds *mockDatastore) ObjectiveSet(infObjective *v1alpha2.InferenceObjective) {}
-func (ds *mockDatastore) ObjectiveDelete(namespacedName types.NamespacedName) {}
-func (ds *mockDatastore) ObjectiveGetAll() []*v1alpha2.InferenceObjective { return nil }
-func (ds *mockDatastore) PodAddRequest(podName types.NamespacedName, requestID string, tpot float64) error {
- return nil
-}
-func (ds *mockDatastore) PodRemoveRequest(podName types.NamespacedName, requestID string) error {
- return nil
-}
-func (ds *mockDatastore) PodUpdateRequest(podName types.NamespacedName, requestID string, tpot float64) error {
- return nil
-}
-func (ds *mockDatastore) PodGetRunningRequests(podName types.NamespacedName) (*datalayer.RequestPriorityQueue, error) {
- return nil, nil
-}
-func (ds *mockDatastore) PodGetRequestCount(podName types.NamespacedName) (int, error) { return 0, nil }
-func (ds *mockDatastore) Clear() {}
-// mockPredictor implements the Predictor interface for testing.
type mockPredictor struct {
PredictFunc func(ctx context.Context, req latencypredictor.PredictionRequest) (*latencypredictor.PredictionResponse, error)
trainingSamples []latencypredictor.TrainingEntry
@@ -201,7 +185,6 @@ func (m *mockPredictor) AddTrainingDataBulk(entry []latencypredictor.TrainingEnt
m.trainingSamples = append(m.trainingSamples, entry...)
return nil
}
-
func TestDirector_HandleRequest(t *testing.T) {
ctx := logutil.NewTestLoggerIntoContext(context.Background())
@@ -229,7 +212,7 @@ func TestDirector_HandleRequest(t *testing.T) {
// Datastore setup
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second)
- ds := datastore.NewDatastore(t.Context(), pmf)
+ ds := datastore.NewDatastore(t.Context(), pmf, 0)
ds.ObjectiveSet(ioFoodReview)
ds.ObjectiveSet(ioFoodReviewResolve)
ds.ObjectiveSet(ioFoodReviewSheddable)
@@ -279,6 +262,8 @@ func TestDirector_HandleRequest(t *testing.T) {
Pod: &schedulingtypes.PodMetrics{
Pod: &backend.Pod{
Address: "192.168.1.100",
+ Port: "8000",
+ MetricsHost: "192.168.1.100:8000",
NamespacedName: types.NamespacedName{Name: "pod1", Namespace: "default"},
},
},
@@ -287,6 +272,8 @@ func TestDirector_HandleRequest(t *testing.T) {
Pod: &schedulingtypes.PodMetrics{
Pod: &backend.Pod{
Address: "192.168.2.100",
+ Port: "8000",
+ MetricsHost: "192.168.2.100:8000",
NamespacedName: types.NamespacedName{Name: "pod2", Namespace: "default"},
},
},
@@ -295,6 +282,8 @@ func TestDirector_HandleRequest(t *testing.T) {
Pod: &schedulingtypes.PodMetrics{
Pod: &backend.Pod{
Address: "192.168.4.100",
+ Port: "8000",
+ MetricsHost: "192.168.4.100:8000",
NamespacedName: types.NamespacedName{Name: "pod4", Namespace: "default"},
},
},
@@ -310,10 +299,9 @@ func TestDirector_HandleRequest(t *testing.T) {
&schedulingtypes.ScoredPod{
Pod: &schedulingtypes.PodMetrics{
Pod: &backend.Pod{
- Address: "192.168.1.100",
- NamespacedName: types.NamespacedName{Name: "pod1", Namespace: "default"},
- RunningRequests: &datalayer.RequestPriorityQueue{}, // Add empty queue
- Labels: map[string]string{"app": "inference"},
+ Address: "192.168.1.100",
+ NamespacedName: types.NamespacedName{Name: "pod1", Namespace: "default"},
+ Labels: map[string]string{"app": "inference"},
},
},
},
@@ -323,10 +311,9 @@ func TestDirector_HandleRequest(t *testing.T) {
&schedulingtypes.ScoredPod{
Pod: &schedulingtypes.PodMetrics{
Pod: &backend.Pod{
- Address: "192.168.1.100",
- NamespacedName: types.NamespacedName{Name: "pod1", Namespace: "default"},
- RunningRequests: &datalayer.RequestPriorityQueue{}, // Add empty queue
- Labels: map[string]string{"app": "inference"},
+ Address: "192.168.1.100",
+ NamespacedName: types.NamespacedName{Name: "pod1", Namespace: "default"},
+ Labels: map[string]string{"app": "inference"},
},
},
}: 0.8, // 80% prefix cache score
@@ -337,24 +324,24 @@ func TestDirector_HandleRequest(t *testing.T) {
}
tests := []struct {
- name string
- reqBodyMap map[string]any
- mockSaturationDetector *mockSaturationDetector
- inferenceObjectiveName string
- schedulerMockSetup func(m *mockScheduler)
- predictorMockSetup func(m *mockPredictor) // NEW: Add predictor setup
- wantErrCode string // Expected errutil code string
- wantReqCtx *handlers.RequestContext // Fields to check in the returned RequestContext
- wantMutatedBodyModel string // Expected model in reqCtx.Request.Body after PostDispatch
- targetModelName string // Expected model name after target model resolution
+ name string
+ reqBodyMap map[string]any
+ mockAdmissionController *mockAdmissionController
+ inferenceObjectiveName string
+ schedulerMockSetup func(m *mockScheduler)
+ predictorMockSetup func(m *mockPredictor)
+ wantErrCode string // Expected errutil code string
+ wantReqCtx *handlers.RequestContext // Fields to check in the returned RequestContext
+ wantMutatedBodyModel string // Expected model in reqCtx.Request.Body after PostDispatch
+ targetModelName string // Expected model name after target model resolution
}{
{
- name: "successful completions request (critical, saturation ignored)",
+ name: "successful completions request",
reqBodyMap: map[string]any{
"model": model,
"prompt": "critical prompt",
},
- mockSaturationDetector: &mockSaturationDetector{isSaturated: true},
+ mockAdmissionController: &mockAdmissionController{admitErr: nil},
schedulerMockSetup: func(m *mockScheduler) {
m.scheduleResults = defaultSuccessfulScheduleResults
},
@@ -362,9 +349,10 @@ func TestDirector_HandleRequest(t *testing.T) {
ObjectiveKey: objectiveName,
TargetModelName: model,
TargetPod: &backend.Pod{
- NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
- Address: "192.168.1.100",
- RunningRequests: &datalayer.RequestPriorityQueue{}, // Empty but initialized
+ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+ Address: "192.168.1.100",
+ Port: "8000",
+ MetricsHost: "192.168.1.100:8000",
},
TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000",
},
@@ -373,39 +361,7 @@ func TestDirector_HandleRequest(t *testing.T) {
targetModelName: model,
},
{
- name: "non-critical request dropped due to saturation",
- reqBodyMap: map[string]any{
- "model": modelSheddable,
- "prompt": "test prompt",
- },
- mockSaturationDetector: &mockSaturationDetector{isSaturated: true},
- schedulerMockSetup: func(m *mockScheduler) {
- m.scheduleResults = defaultSuccessfulScheduleResults
- },
- wantReqCtx: &handlers.RequestContext{
- ObjectiveKey: objectiveNameSheddable,
- TargetModelName: model,
- TargetPod: &backend.Pod{
- NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
- Address: "192.168.1.100",
- RunningRequests: &datalayer.RequestPriorityQueue{}, // Empty but initialized
- },
- TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000",
- },
- predictorMockSetup: func(m *mockPredictor) {
- // Mock prediction that violates SLOs
- m.PredictFunc = func(ctx context.Context, req latencypredictor.PredictionRequest) (*latencypredictor.PredictionResponse, error) {
- return &latencypredictor.PredictionResponse{
- TTFT: 150.0, // Above SLO of 100
- TPOT: 80.0, // Above SLO of 50
- }, nil
- }
- },
- inferenceObjectiveName: objectiveNameSheddable,
- wantErrCode: errutil.InferencePoolResourceExhausted,
- },
- {
- name: "successful chat completions request (default critical, saturation ignored)",
+ name: "successful chat completions request",
reqBodyMap: map[string]any{
"model": model,
"messages": []any{
@@ -415,7 +371,7 @@ func TestDirector_HandleRequest(t *testing.T) {
},
},
},
- mockSaturationDetector: &mockSaturationDetector{isSaturated: true},
+ mockAdmissionController: &mockAdmissionController{admitErr: nil},
schedulerMockSetup: func(m *mockScheduler) {
m.scheduleResults = defaultSuccessfulScheduleResults
},
@@ -424,6 +380,8 @@ func TestDirector_HandleRequest(t *testing.T) {
TargetPod: &backend.Pod{
NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
Address: "192.168.1.100",
+ Port: "8000",
+ MetricsHost: "192.168.1.100:8000",
},
TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000",
},
@@ -436,7 +394,7 @@ func TestDirector_HandleRequest(t *testing.T) {
"model": model, // Critical model
"prompt": "test prompt",
},
- mockSaturationDetector: &mockSaturationDetector{isSaturated: true},
+ mockAdmissionController: &mockAdmissionController{admitErr: nil},
schedulerMockSetup: func(m *mockScheduler) {
m.scheduleResults = defaultSuccessfulScheduleResults
},
@@ -452,9 +410,10 @@ func TestDirector_HandleRequest(t *testing.T) {
wantReqCtx: &handlers.RequestContext{
TargetModelName: model,
TargetPod: &backend.Pod{
- NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
- Address: "192.168.1.100",
- RunningRequests: &datalayer.RequestPriorityQueue{}, // Empty but initialized
+ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+ Address: "192.168.1.100",
+ Port: "8000",
+ MetricsHost: "192.168.1.100:8000",
},
TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000",
},
@@ -476,6 +435,7 @@ func TestDirector_HandleRequest(t *testing.T) {
},
},
},
+ mockAdmissionController: &mockAdmissionController{admitErr: nil},
schedulerMockSetup: func(m *mockScheduler) {
m.scheduleResults = defaultSuccessfulScheduleResults
},
@@ -483,9 +443,10 @@ func TestDirector_HandleRequest(t *testing.T) {
ObjectiveKey: objectiveName,
TargetModelName: model,
TargetPod: &backend.Pod{
- NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
- Address: "192.168.1.100",
- RunningRequests: &datalayer.RequestPriorityQueue{}, // Empty but initialized
+ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+ Address: "192.168.1.100",
+ Port: "8000",
+ MetricsHost: "192.168.1.100:8000",
},
TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000",
},
@@ -499,7 +460,7 @@ func TestDirector_HandleRequest(t *testing.T) {
"model": modelSheddable,
"prompt": "sheddable prompt",
},
- mockSaturationDetector: &mockSaturationDetector{isSaturated: false},
+ mockAdmissionController: &mockAdmissionController{admitErr: nil},
schedulerMockSetup: func(m *mockScheduler) {
m.scheduleResults = defaultSuccessfulScheduleResults
},
@@ -507,9 +468,10 @@ func TestDirector_HandleRequest(t *testing.T) {
ObjectiveKey: objectiveNameSheddable,
TargetModelName: modelSheddable,
TargetPod: &backend.Pod{
- NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
- Address: "192.168.1.100",
- RunningRequests: &datalayer.RequestPriorityQueue{}, // Empty but initialized
+ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+ Address: "192.168.1.100",
+ Port: "8000",
+ MetricsHost: "192.168.1.100:8000",
},
TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000",
},
@@ -523,7 +485,7 @@ func TestDirector_HandleRequest(t *testing.T) {
"model": modelWithResolvedTarget,
"prompt": "prompt for target resolution",
},
- mockSaturationDetector: &mockSaturationDetector{isSaturated: false},
+ mockAdmissionController: &mockAdmissionController{admitErr: nil},
schedulerMockSetup: func(m *mockScheduler) {
m.scheduleResults = defaultSuccessfulScheduleResults
},
@@ -531,9 +493,10 @@ func TestDirector_HandleRequest(t *testing.T) {
ObjectiveKey: objectiveNameResolve,
TargetModelName: "resolved-target-model-A",
TargetPod: &backend.Pod{
- NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
- Address: "192.168.1.100",
- RunningRequests: &datalayer.RequestPriorityQueue{}, // Empty but initialized
+ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+ Address: "192.168.1.100",
+ Port: "8000",
+ MetricsHost: "192.168.1.100:8000",
},
TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000",
},
@@ -550,9 +513,10 @@ func TestDirector_HandleRequest(t *testing.T) {
ObjectiveKey: "food-review-1",
TargetModelName: "food-review-1",
TargetPod: &backend.Pod{
- NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
- Address: "192.168.1.100",
- RunningRequests: &datalayer.RequestPriorityQueue{}, // Empty but initialized
+ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+ Address: "192.168.1.100",
+ Port: "8000",
+ MetricsHost: "192.168.1.100:8000",
},
TargetEndpoint: "192.168.1.100:8000,192.168.2.100:8000,192.168.4.100:8000",
},
@@ -561,28 +525,27 @@ func TestDirector_HandleRequest(t *testing.T) {
"model": "food-review-1",
"prompt": "test prompt",
},
- mockSaturationDetector: &mockSaturationDetector{isSaturated: false},
- inferenceObjectiveName: "food-review-1",
- targetModelName: "food-review-1",
+ mockAdmissionController: &mockAdmissionController{admitErr: nil},
+ inferenceObjectiveName: "food-review-1",
+ targetModelName: "food-review-1",
},
{
- name: "request dropped (sheddable, saturated)",
+ name: "request rejected by admission controller",
reqBodyMap: map[string]any{
"model": modelSheddable,
"prompt": "sheddable prompt",
},
- inferenceObjectiveName: objectiveNameSheddable,
- mockSaturationDetector: &mockSaturationDetector{isSaturated: true},
- wantErrCode: errutil.InferencePoolResourceExhausted,
+ inferenceObjectiveName: objectiveNameSheddable,
+ mockAdmissionController: &mockAdmissionController{admitErr: errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: "simulated admission rejection"}},
+ wantErrCode: errutil.InferencePoolResourceExhausted,
},
{
- name: "model not found, expect err",
- reqBodyMap: map[string]any{"prompt": "p"},
- mockSaturationDetector: &mockSaturationDetector{isSaturated: false},
- wantErrCode: errutil.BadRequest,
+ name: "model not found, expect err",
+ reqBodyMap: map[string]any{"prompt": "p"},
+ mockAdmissionController: &mockAdmissionController{admitErr: nil},
+ wantErrCode: errutil.BadRequest,
},
-
{
name: "prompt or messages not found, expect err",
reqBodyMap: map[string]any{"model": model},
@@ -602,6 +565,7 @@ func TestDirector_HandleRequest(t *testing.T) {
"model": model,
"prompt": "prompt that causes scheduler error",
},
+ mockAdmissionController: &mockAdmissionController{admitErr: nil},
schedulerMockSetup: func(m *mockScheduler) {
m.scheduleErr = errors.New("simulated scheduler failure")
},
@@ -614,6 +578,7 @@ func TestDirector_HandleRequest(t *testing.T) {
"model": model,
"prompt": "prompt for nil,nil scheduler return",
},
+ mockAdmissionController: &mockAdmissionController{admitErr: nil},
schedulerMockSetup: func(m *mockScheduler) {
m.scheduleResults = nil
m.scheduleErr = nil
@@ -636,9 +601,9 @@ func TestDirector_HandleRequest(t *testing.T) {
if test.predictorMockSetup != nil {
mockPred = &mockPredictor{}
test.predictorMockSetup(mockPred)
- director = NewDirectorWithConfig(ds, mockSched, test.mockSaturationDetector, NewConfig())
+ director = NewDirectorWithConfig(ds, mockSched, test.mockAdmissionController, NewConfig())
} else {
- director = NewDirectorWithConfig(ds, mockSched, test.mockSaturationDetector, NewConfig())
+ director = NewDirectorWithConfig(ds, mockSched, test.mockAdmissionController, NewConfig())
}
reqCtx := &handlers.RequestContext{
@@ -674,15 +639,7 @@ func TestDirector_HandleRequest(t *testing.T) {
assert.Equal(t, test.wantReqCtx.ObjectiveKey, returnedReqCtx.ObjectiveKey, "reqCtx.Model mismatch")
assert.Equal(t, test.wantReqCtx.TargetModelName, returnedReqCtx.TargetModelName,
"reqCtx.ResolvedTargetModel mismatch")
- if test.wantReqCtx != nil && test.wantReqCtx.TargetPod != nil {
- expected := test.wantReqCtx.TargetPod
- actual := returnedReqCtx.TargetPod
-
- assert.Equal(t, expected.NamespacedName, actual.NamespacedName, "NamespacedName mismatch")
- assert.Equal(t, expected.Address, actual.Address, "Address mismatch")
- assert.Equal(t, expected.Labels, actual.Labels, "Labels mismatch")
- // Skip RunningRequests comparison - it's not relevant to the test
- }
+ assert.Equal(t, test.wantReqCtx.TargetPod, returnedReqCtx.TargetPod, "reqCtx.TargetPod mismatch")
assert.Equal(t, test.wantReqCtx.TargetEndpoint, returnedReqCtx.TargetEndpoint, "reqCtx.TargetEndpoint mismatch")
}
@@ -766,13 +723,13 @@ func TestGetCandidatePodsForScheduling(t *testing.T) {
ds := &mockDatastore{pods: testInput}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- director := NewDirectorWithConfig(ds, &mockScheduler{}, &mockSaturationDetector{}, NewConfig())
+ director := NewDirectorWithConfig(ds, &mockScheduler{}, &mockAdmissionController{}, NewConfig())
got := director.getCandidatePodsForScheduling(context.Background(), test.metadata)
diff := cmp.Diff(test.output, got, cmpopts.SortSlices(func(a, b backendmetrics.PodMetrics) bool {
return a.GetPod().NamespacedName.String() < b.GetPod().NamespacedName.String()
- }), cmpopts.IgnoreUnexported(backendmetrics.FakePodMetrics{}))
+ }))
if diff != "" {
t.Errorf("Unexpected output (-want +got): %v", diff)
}
@@ -809,10 +766,29 @@ func TestGetRandomPod(t *testing.T) {
},
}
+ scheme := runtime.NewScheme()
+ _ = clientgoscheme.AddToScheme(scheme)
+ _ = v1alpha2.Install(scheme)
+ _ = v1.Install(scheme)
+ fakeClient := fake.NewClientBuilder().
+ WithScheme(scheme).
+ Build()
+ pool := &v1.InferencePool{
+ Spec: v1.InferencePoolSpec{
+ TargetPorts: []v1.Port{
+ {Number: 8000},
+ },
+ },
+ }
+
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Millisecond)
- ds := datastore.NewDatastore(t.Context(), pmf)
+ ds := datastore.NewDatastore(t.Context(), pmf, 0)
+ err := ds.PoolSet(t.Context(), fakeClient, pool)
+ if err != nil {
+ t.Errorf("unexpected error setting pool: %s", err)
+ }
for _, pod := range test.storePods {
ds.PodUpdateOrAddIfNotExist(pod)
}
@@ -829,13 +805,13 @@ func TestGetRandomPod(t *testing.T) {
}
}
-func TestDirector_HandleResponse(t *testing.T) {
- pr1 := newTestPostResponse("pr1")
+func TestDirector_HandleResponseReceived(t *testing.T) {
+ pr1 := newTestResponseReceived("pr1")
ctx := logutil.NewTestLoggerIntoContext(context.Background())
- ds := datastore.NewDatastore(t.Context(), nil)
+ ds := datastore.NewDatastore(t.Context(), nil, 0)
mockSched := &mockScheduler{}
- director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithPostResponsePlugins(pr1))
+ director := NewDirectorWithConfig(ds, mockSched, &mockAdmissionController{}, NewConfig().WithResponseReceivedPlugins(pr1))
reqCtx := &handlers.RequestContext{
Request: &handlers.Request{
@@ -850,7 +826,7 @@ func TestDirector_HandleResponse(t *testing.T) {
TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
}
- _, err := director.HandleResponse(ctx, reqCtx)
+ _, err := director.HandleResponseReceived(ctx, reqCtx)
if err != nil {
t.Fatalf("HandleResponse() returned unexpected error: %v", err)
}
@@ -866,31 +842,143 @@ func TestDirector_HandleResponse(t *testing.T) {
}
}
+func TestDirector_HandleResponseStreaming(t *testing.T) {
+ ps1 := newTestResponseStreaming("ps1")
+
+ ctx := logutil.NewTestLoggerIntoContext(context.Background())
+ ds := datastore.NewDatastore(t.Context(), nil, 0)
+ mockSched := &mockScheduler{}
+ director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithResponseStreamingPlugins(ps1))
+
+ reqCtx := &handlers.RequestContext{
+ Request: &handlers.Request{
+ Headers: map[string]string{
+ requtil.RequestIdHeaderKey: "test-req-id-for-streaming",
+ },
+ },
+ Response: &handlers.Response{
+ Headers: map[string]string{"X-Test-Streaming-Header": "StreamValue"},
+ },
+ TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
+ }
+
+ _, err := director.HandleResponseBodyStreaming(ctx, reqCtx)
+ if err != nil {
+ t.Fatalf("HandleResponseBodyStreaming() returned unexpected error: %v", err)
+ }
+
+ if diff := cmp.Diff("test-req-id-for-streaming", ps1.lastRespOnStreaming.RequestId); diff != "" {
+ t.Errorf("Scheduler.OnStreaming RequestId mismatch (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(reqCtx.Response.Headers, ps1.lastRespOnStreaming.Headers); diff != "" {
+ t.Errorf("Scheduler.OnStreaming Headers mismatch (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff("namespace1/test-pod-name", ps1.lastTargetPodOnStreaming); diff != "" {
+ t.Errorf("Scheduler.OnStreaming TargetPodName mismatch (-want +got):\n%s", diff)
+ }
+}
+
+func TestDirector_HandleResponseComplete(t *testing.T) {
+ pc1 := newTestResponseComplete("pc1")
+
+ ctx := logutil.NewTestLoggerIntoContext(context.Background())
+ ds := datastore.NewDatastore(t.Context(), nil, 0)
+ mockSched := &mockScheduler{}
+ director := NewDirectorWithConfig(ds, mockSched, nil, NewConfig().WithResponseCompletePlugins(pc1))
+
+ reqCtx := &handlers.RequestContext{
+ Request: &handlers.Request{
+ Headers: map[string]string{
+ requtil.RequestIdHeaderKey: "test-req-id-for-complete",
+ },
+ },
+ Response: &handlers.Response{
+ Headers: map[string]string{"X-Test-Complete-Header": "CompleteValue"},
+ },
+ TargetPod: &backend.Pod{NamespacedName: types.NamespacedName{Namespace: "namespace1", Name: "test-pod-name"}},
+ }
+
+ _, err := director.HandleResponseBodyComplete(ctx, reqCtx)
+ if err != nil {
+ t.Fatalf("HandleResponseBodyComplete() returned unexpected error: %v", err)
+ }
+
+ if diff := cmp.Diff("test-req-id-for-complete", pc1.lastRespOnComplete.RequestId); diff != "" {
+ t.Errorf("Scheduler.OnComplete RequestId mismatch (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff(reqCtx.Response.Headers, pc1.lastRespOnComplete.Headers); diff != "" {
+ t.Errorf("Scheduler.OnComplete Headers mismatch (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff("namespace1/test-pod-name", pc1.lastTargetPodOnComplete); diff != "" {
+ t.Errorf("Scheduler.OnComplete TargetPodName mismatch (-want +got):\n%s", diff)
+ }
+}
+
const (
- testPostResponseType = "test-post-response"
+ testResponseReceivedType = "test-response-received"
+ testPostStreamingType = "test-response-streaming"
+ testPostCompleteType = "test-response-complete"
)
-type testPostResponse struct {
+type testResponseReceived struct {
tn plugins.TypedName
lastRespOnResponse *Response
lastTargetPodOnResponse string
}
-func newTestPostResponse(name string) *testPostResponse {
- return &testPostResponse{
- tn: plugins.TypedName{Type: testPostResponseType, Name: name},
+type testResponseStreaming struct {
+ tn plugins.TypedName
+ lastRespOnStreaming *Response
+ lastTargetPodOnStreaming string
+}
+
+type testResponseComplete struct {
+ tn plugins.TypedName
+ lastRespOnComplete *Response
+ lastTargetPodOnComplete string
+}
+
+func newTestResponseReceived(name string) *testResponseReceived {
+ return &testResponseReceived{
+ tn: plugins.TypedName{Type: testResponseReceivedType, Name: name},
}
}
-func (p *testPostResponse) TypedName() plugins.TypedName {
- return p.tn
+func newTestResponseStreaming(name string) *testResponseStreaming {
+ return &testResponseStreaming{
+ tn: plugins.TypedName{Type: testPostStreamingType, Name: name},
+ }
}
-func (p *testPostResponse) PostResponse(_ context.Context, reqCtx *handlers.RequestContext) {
- response := &Response{
- RequestId: reqCtx.Request.Headers[requtil.RequestIdHeaderKey],
- Headers: reqCtx.Response.Headers,
+func newTestResponseComplete(name string) *testResponseComplete {
+ return &testResponseComplete{
+ tn: plugins.TypedName{Type: testPostCompleteType, Name: name},
}
+}
+
+func (p *testResponseReceived) TypedName() plugins.TypedName {
+ return p.tn
+}
+
+func (p *testResponseStreaming) TypedName() plugins.TypedName {
+ return p.tn
+}
+
+func (p *testResponseComplete) TypedName() plugins.TypedName {
+ return p.tn
+}
+
+func (p *testResponseReceived) ResponseReceived(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
p.lastRespOnResponse = response
- p.lastTargetPodOnResponse = reqCtx.TargetPod.NamespacedName.String()
+ p.lastTargetPodOnResponse = targetPod.NamespacedName.String()
+}
+
+func (p *testResponseStreaming) ResponseStreaming(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
+ p.lastRespOnStreaming = response
+ p.lastTargetPodOnStreaming = targetPod.NamespacedName.String()
+}
+
+func (p *testResponseComplete) ResponseComplete(_ context.Context, _ *schedulingtypes.LLMRequest, response *Response, targetPod *backend.Pod) {
+ p.lastRespOnComplete = response
+ p.lastTargetPodOnComplete = targetPod.NamespacedName.String()
}
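
Taken together, the three tests above imply the following call order for a streaming response. The helper below is a hypothetical sketch (driveResponseHooks and its chunks parameter are not part of this change), shown only to make the renamed hook sequence explicit.

// driveResponseHooks sketches the expected ordering: received once, streaming per chunk, complete once.
func driveResponseHooks(ctx context.Context, d *Director, reqCtx *handlers.RequestContext, chunks int) error {
	if _, err := d.HandleResponseReceived(ctx, reqCtx); err != nil {
		return err
	}
	for i := 0; i < chunks; i++ { // one call per body chunk relayed by ext_proc
		if _, err := d.HandleResponseBodyStreaming(ctx, reqCtx); err != nil {
			return err
		}
	}
	_, err := d.HandleResponseBodyComplete(ctx, reqCtx)
	return err
}
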
diff --git a/pkg/epp/requestcontrol/plugins.go b/pkg/epp/requestcontrol/plugins.go
index 1bb56062a..30f31f070 100644
--- a/pkg/epp/requestcontrol/plugins.go
+++ b/pkg/epp/requestcontrol/plugins.go
@@ -19,39 +19,41 @@ package requestcontrol
import (
"context"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)
const (
- PreRequestExtensionPoint = "PreRequest"
- PostResponseExtensionPoint = "PostResponse"
- PostResponseChunkExtensionPoint = "PostResponseChunk"
- PostResponseCompleteExtensionPoint = "PostResponseComplete"
+ PreRequestExtensionPoint = "PreRequest"
+ ResponseReceivedExtensionPoint = "ResponseReceived"
+ ResponseStreamingExtensionPoint = "ResponseStreaming"
+ ResponseCompleteExtensionPoint = "ResponseComplete"
)
-// PreRequest is called by the director after a getting result from scheduling layer but
+// PreRequest is called by the director after getting a result from the scheduling layer and
// before a request is sent to the selected model server.
type PreRequest interface {
plugins.Plugin
- PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult, targetPort int)
+ PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult)
}
-// PostResponse is called by the director after a successful response is recieved or first chunk if streaming.
-type PostResponse interface {
+// ResponseReceived is called by the director after the response headers are successfully received,
+// which indicates that the model server has started producing the response.
+// The given pod argument is the pod that served the request.
+type ResponseReceived interface {
plugins.Plugin
- PostResponse(ctx context.Context, reqCtx *handlers.RequestContext)
+ ResponseReceived(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod)
}
-// PostResponseChunk is called by the director if in streaming mode after each successful response chunk.
-type PostResponseChunk interface {
+// ResponseStreaming is called by the director after each chunk of a streaming response is sent.
+type ResponseStreaming interface {
plugins.Plugin
- PostResponseChunk(ctx context.Context, reqCtx *handlers.RequestContext)
+ ResponseStreaming(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod)
}
-// PostResponseComplete is called by the director if in streaming mode after the final successful response chunk is sent.
-type PostResponseComplete interface {
+// ResponseComplete is called by the director after the complete response is sent.
+type ResponseComplete interface {
plugins.Plugin
- PostResponseComplete(ctx context.Context, reqCtx *handlers.RequestContext)
+ ResponseComplete(ctx context.Context, request *types.LLMRequest, response *Response, targetPod *backend.Pod)
}
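
A minimal plugin targeting one of the renamed extension points might look like the sketch below. The type and its logging are illustrative assumptions (it also presumes the controller-runtime log and logutil helpers used elsewhere in this package), not part of this change.

// loggingResponseComplete is a hypothetical ResponseComplete plugin that only logs.
type loggingResponseComplete struct {
	tn plugins.TypedName
}

func (p *loggingResponseComplete) TypedName() plugins.TypedName { return p.tn }

func (p *loggingResponseComplete) ResponseComplete(ctx context.Context, _ *types.LLMRequest,
	response *Response, targetPod *backend.Pod) {
	log.FromContext(ctx).V(logutil.DEBUG).Info("response complete",
		"requestID", response.RequestId, "pod", targetPod.NamespacedName)
}
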
diff --git a/pkg/epp/requestcontrol/plugins/slorequest/slo_request_tracker.go b/pkg/epp/requestcontrol/plugins/slorequest/slo_request_tracker.go
deleted file mode 100644
index cc57a9963..000000000
--- a/pkg/epp/requestcontrol/plugins/slorequest/slo_request_tracker.go
+++ /dev/null
@@ -1,177 +0,0 @@
-/*
-Copyright 2025 The Kubernetes Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-
-package slorequest
-
-import (
- "context"
- "time"
-
- "github.com/go-logr/logr"
- "github.com/google/uuid"
- "k8s.io/apimachinery/pkg/types"
- "sigs.k8s.io/controller-runtime/pkg/log"
-
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/latencypredictorasync"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
- scheduling_types "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
- logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
- requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
-)
-
-const (
- SLORequestTrackerPluginType = "slo-request-tracker"
-)
-
-type SLORequestTracker struct {
- tn plugins.TypedName
- latencypredictor latencypredictorasync.PredictorInterface
- datastore datastore.Datastore
-}
-
-var _ requestcontrol.PreRequest = &SLORequestTracker{}
-var _ requestcontrol.PostResponse = &SLORequestTracker{}
-var _ requestcontrol.PostResponseChunk = &SLORequestTracker{}
-var _ requestcontrol.PostResponseComplete = &SLORequestTracker{}
-
-func New(latencypredictor latencypredictorasync.PredictorInterface, datastore datastore.Datastore) *SLORequestTracker {
- return &SLORequestTracker{
- tn: plugins.TypedName{Type: SLORequestTrackerPluginType, Name: SLORequestTrackerPluginType},
- latencypredictor: latencypredictor,
- datastore: datastore,
- }
-}
-
-func (t *SLORequestTracker) TypedName() plugins.TypedName {
- return t.tn
-}
-
-func (s *SLORequestTracker) WithName(name string) *SLORequestTracker {
- s.tn.Name = name
- return s
-}
-
-func (t *SLORequestTracker) PreRequest(ctx context.Context, request *scheduling_types.LLMRequest, schedulingResult *scheduling_types.SchedulingResult, targetPort int) {
- logger := log.FromContext(ctx)
-
- if schedulingResult == nil || len(schedulingResult.ProfileResults) == 0 {
- logger.V(logutil.DEBUG).Info("SLORequestTracker: Skipping PreRequest because no scheduling result was provided.")
- return
- }
-
- targetPod := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName].TargetPods[0].GetPod()
-
- podName := types.NamespacedName{
- Name: targetPod.NamespacedName.Name,
- Namespace: targetPod.NamespacedName.Namespace,
- }
-
- logger.V(logutil.DEBUG).Info("request ID for SLO tracking", "requestID", request.Headers[requtil.RequestIdHeaderKey], "podName", podName)
- if request.Headers[requtil.RequestIdHeaderKey] == "" {
- request.Headers[requtil.RequestIdHeaderKey] = uuid.New().String()
- logger.V(logutil.DEBUG).Info("Generated new request ID for SLO tracking", "requestID", request.Headers[requtil.RequestIdHeaderKey])
- logger.V(logutil.DEBUG).Info("request headers for SLO tracking", "requestHeaders", request.Headers)
- }
-
- err := t.datastore.PodAddRequest(podName, request.Headers[requtil.RequestIdHeaderKey], request.AvgTPOTSLO)
- if err != nil {
- logger.V(logutil.DEBUG).Error(err, "SLORequestTracker: Failed to add request to pod running queue", "podName", podName, "requestID", request.Headers[requtil.RequestIdHeaderKey])
- }
-}
-
-func (t *SLORequestTracker) PostResponse(ctx context.Context, reqCtx *handlers.RequestContext) {
- logger := log.FromContext(ctx)
- targetPod := reqCtx.TargetPod
- if !t.CheckPredictor(logger, targetPod) {
- return
- }
-
- if err := requestcontrol.ProcessHeaderForLatencyPrediction(ctx, t.latencypredictor, reqCtx); err != nil {
- logger.V(logutil.DEBUG).Error(err, "ProcessHeader in latencypredictor failed")
- }
-
-}
-
-func (t *SLORequestTracker) PostResponseChunk(ctx context.Context, reqCtx *handlers.RequestContext) {
- logger := log.FromContext(ctx)
- targetPod := reqCtx.TargetPod
- if !t.CheckPredictor(logger, targetPod) {
- return
- }
-
- now := time.Now()
-
- if reqCtx.TTFT == 0 {
- requestcontrol.ProcessFirstTokenForLatencyPrediction(ctx, t.latencypredictor, reqCtx, now)
- } else {
- requestcontrol.ProcessTokenForLatencyPrediction(ctx, t.latencypredictor, reqCtx, now)
- }
-
-}
-
-func (t *SLORequestTracker) PostResponseComplete(ctx context.Context, reqCtx *handlers.RequestContext) {
- logger := log.FromContext(ctx)
- request := reqCtx.SchedulingRequest
- targetPod := reqCtx.TargetPod
- if !t.CheckPredictor(logger, targetPod) {
- return
- }
-
- if reqCtx.TTFT > 0 {
- logger.V(logutil.DEBUG).Info("Averages calculated", "avgActualTTFT", reqCtx.TTFT, "avgPredictedTTFT", reqCtx.PredictedTTFT)
- metrics.RecordRequestTTFT(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.TTFT/1000)
- metrics.RecordRequestPredictedTTFT(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.PredictedTTFT/1000)
- if reqCtx.SchedulingRequest.TTFTSLO > 0 {
- metrics.RecordRequestTTFTWithSLO(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.TTFT, reqCtx.SchedulingRequest.TTFTSLO)
- }
- }
-
- if reqCtx.AvgTPOT > 0 {
- logger.V(logutil.DEBUG).Info("Averages calculated", "avgActualTPOT", reqCtx.AvgTPOT, "avgPredictedTPOT", reqCtx.AvgPredictedTPOT)
- metrics.RecordRequestTPOT(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.AvgTPOT/1000)
- metrics.RecordRequestPredictedTPOT(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.AvgPredictedTPOT/1000)
- if reqCtx.SchedulingRequest.AvgTPOTSLO > 0 {
- metrics.RecordRequestTPOTWithSLO(ctx, reqCtx.IncomingModelName, reqCtx.TargetModelName, reqCtx.AvgTPOT, reqCtx.SchedulingRequest.AvgTPOTSLO)
- }
- }
- logger.V(logutil.DEBUG).Info("SLO Aware Routing Mode", "PredictorBasedScheduling", request.PredictorBasedScheduling)
-
- podName := types.NamespacedName{
- Name: targetPod.NamespacedName.Name,
- Namespace: targetPod.NamespacedName.Namespace,
- }
-
- if err := t.datastore.PodRemoveRequest(podName, request.Headers[requtil.RequestIdHeaderKey]); err != nil {
- logger.V(logutil.DEBUG).Error(err, "SLORequestTracker: Failed to remove request from queue", "requestID", request.Headers[requtil.RequestIdHeaderKey])
- }
-}
-
-func (t *SLORequestTracker) CheckPredictor(logger logr.Logger, targetPod *backend.Pod) bool {
- if targetPod == nil {
- logger.V(logutil.DEBUG).Info("SLORequestTracker: Skipping PostResponse because no target pod was provided.")
- return false
- }
- if t.latencypredictor == nil {
- logger.V(logutil.DEBUG).Info("SLORequestTracker: Skipping PostResponse because predictor missing")
- return false
- }
- return true
-}
diff --git a/pkg/epp/requestcontrol/request_control_config.go b/pkg/epp/requestcontrol/request_control_config.go
index 32b68a38b..ffa6c6609 100644
--- a/pkg/epp/requestcontrol/request_control_config.go
+++ b/pkg/epp/requestcontrol/request_control_config.go
@@ -23,19 +23,19 @@ import (
// NewConfig creates a new Config object and returns its pointer.
func NewConfig() *Config {
return &Config{
- preRequestPlugins: []PreRequest{},
- postResponsePlugins: []PostResponse{},
- postResponseChunkPlugins: []PostResponseChunk{},
- postResponseCompletePlugins: []PostResponseComplete{},
+ preRequestPlugins: []PreRequest{},
+ responseReceivedPlugins: []ResponseReceived{},
+ responseStreamingPlugins: []ResponseStreaming{},
+ responseCompletePlugins: []ResponseComplete{},
}
}
// Config provides a configuration for the requestcontrol plugins.
type Config struct {
- preRequestPlugins []PreRequest
- postResponsePlugins []PostResponse
- postResponseChunkPlugins []PostResponseChunk
- postResponseCompletePlugins []PostResponseComplete
+ preRequestPlugins []PreRequest
+ responseReceivedPlugins []ResponseReceived
+ responseStreamingPlugins []ResponseStreaming
+ responseCompletePlugins []ResponseComplete
}
// WithPreRequestPlugins sets the given plugins as the PreRequest plugins.
@@ -45,40 +45,44 @@ func (c *Config) WithPreRequestPlugins(plugins ...PreRequest) *Config {
return c
}
-// WithPostResponsePlugins sets the given plugins as the PostResponse plugins.
-// If the Config has PostResponse plugins already, this call replaces the existing plugins with the given ones.
-func (c *Config) WithPostResponsePlugins(plugins ...PostResponse) *Config {
- c.postResponsePlugins = plugins
+// WithResponseReceivedPlugins sets the given plugins as the ResponseReceived plugins.
+// If the Config has ResponseReceived plugins already, this call replaces the existing plugins with the given ones.
+func (c *Config) WithResponseReceivedPlugins(plugins ...ResponseReceived) *Config {
+ c.responseReceivedPlugins = plugins
return c
}
-// WithPostResponsePlugins sets the given plugins as the PostResponse plugins.
-// If the Config has PostResponse plugins already, this call replaces the existing plugins with the given ones.
-func (c *Config) WithPostResponseChunkPlugins(plugins ...PostResponseChunk) *Config {
- c.postResponseChunkPlugins = plugins
+// WithResponseStreamingPlugins sets the given plugins as the ResponseStreaming plugins.
+// If the Config has ResponseStreaming plugins already, this call replaces the existing plugins with the given ones.
+func (c *Config) WithResponseStreamingPlugins(plugins ...ResponseStreaming) *Config {
+ c.responseStreamingPlugins = plugins
return c
}
-// WithPostResponseCompletePlugins sets the given plugins as the PostResponseComplete plugins.
-// If the Config has PostResponseComplete plugins already, this call replaces the existing plugins with the given ones.
-func (c *Config) WithPostResponseCompletePlugins(plugins ...PostResponseComplete) *Config {
- c.postResponseCompletePlugins = plugins
+// WithResponseCompletePlugins sets the given plugins as the ResponseComplete plugins.
+// If the Config has ResponseComplete plugins already, this call replaces the existing plugins with the given ones.
+func (c *Config) WithResponseCompletePlugins(plugins ...ResponseComplete) *Config {
+ c.responseCompletePlugins = plugins
return c
}
+// AddPlugins adds the given plugins to the Config.
+// Each plugin is type-asserted against the known requestcontrol extension points and appended to the corresponding list in the Config.
+// If a plugin implements multiple plugin interfaces, it is added to each corresponding list.
+
func (c *Config) AddPlugins(pluginObjects ...plugins.Plugin) {
for _, plugin := range pluginObjects {
if preRequestPlugin, ok := plugin.(PreRequest); ok {
c.preRequestPlugins = append(c.preRequestPlugins, preRequestPlugin)
}
- if postResponsePlugin, ok := plugin.(PostResponse); ok {
- c.postResponsePlugins = append(c.postResponsePlugins, postResponsePlugin)
+ if responseReceivedPlugin, ok := plugin.(ResponseReceived); ok {
+ c.responseReceivedPlugins = append(c.responseReceivedPlugins, responseReceivedPlugin)
}
- if postResponseChunkPlugin, ok := plugin.(PostResponseChunk); ok {
- c.postResponseChunkPlugins = append(c.postResponseChunkPlugins, postResponseChunkPlugin)
+ if responseStreamingPlugin, ok := plugin.(ResponseStreaming); ok {
+ c.responseStreamingPlugins = append(c.responseStreamingPlugins, responseStreamingPlugin)
}
- if postResponseCompletePlugin, ok := plugin.(PostResponseComplete); ok {
- c.postResponseCompletePlugins = append(c.postResponseCompletePlugins, postResponseCompletePlugin)
+ if responseCompletePlugin, ok := plugin.(ResponseComplete); ok {
+ c.responseCompletePlugins = append(c.responseCompletePlugins, responseCompletePlugin)
}
}
}
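
Note on the AddPlugins fan-out above: registration is plain interface type-assertion, so a single object that satisfies several extension points ends up in several lists. A minimal, self-contained sketch of the same pattern (hypothetical Foo/Bar interfaces and registry type, not the real requestcontrol ones):

package main

import "fmt"

// Hypothetical stand-ins for the requestcontrol extension points.
type Foo interface{ DoFoo() }
type Bar interface{ DoBar() }

type registry struct {
	foos []Foo
	bars []Bar
}

// add mirrors AddPlugins: one object may be appended to several lists.
func (r *registry) add(objs ...any) {
	for _, o := range objs {
		if f, ok := o.(Foo); ok {
			r.foos = append(r.foos, f)
		}
		if b, ok := o.(Bar); ok {
			r.bars = append(r.bars, b)
		}
	}
}

// both implements Foo and Bar, so add registers it twice.
type both struct{}

func (both) DoFoo() {}
func (both) DoBar() {}

func main() {
	r := &registry{}
	r.add(both{})
	fmt.Println(len(r.foos), len(r.bars)) // prints: 1 1
}
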
diff --git a/pkg/epp/requestcontrol/types.go b/pkg/epp/requestcontrol/types.go
index 8604e1dda..c881ed713 100644
--- a/pkg/epp/requestcontrol/types.go
+++ b/pkg/epp/requestcontrol/types.go
@@ -16,7 +16,7 @@ limitations under the License.
package requestcontrol
-// Response contains information from the response received to be passed to PostResponse plugins
+// Response contains information from the received response, passed to the response-phase requestcontrol plugins
type Response struct {
// RequestId is the Envoy generated Id for the request being processed
RequestId string
diff --git a/pkg/epp/saturationdetector/saturationdetector_test.go b/pkg/epp/saturationdetector/saturationdetector_test.go
index 7d46143c3..0b861d90a 100644
--- a/pkg/epp/saturationdetector/saturationdetector_test.go
+++ b/pkg/epp/saturationdetector/saturationdetector_test.go
@@ -26,133 +26,19 @@ import (
"github.com/go-logr/logr"
"github.com/google/go-cmp/cmp"
- corev1 "k8s.io/api/core/v1"
- metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+ "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
)
-// --- Mock Implementations ---
-
-type mockDatastore struct {
- pods []backendmetrics.PodMetrics
-}
-
-// PodGetAll returns all pod metrics from the fake datastore.
-func (fds *mockDatastore) PodGetAll() []backendmetrics.PodMetrics {
- return fds.pods
-}
-
-func (fds *mockDatastore) PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics {
- res := []backendmetrics.PodMetrics{}
- for _, pm := range fds.pods {
- if predicate(pm) {
- res = append(res, pm)
- }
- }
- return res
-}
-
-// Helper function to create a properly initialized fake pod metrics
-func newMockPodMetrics(name string, metrics *backendmetrics.MetricsState) backendmetrics.PodMetrics {
- // Create a proper k8s pod
- k8sPod := &corev1.Pod{
- ObjectMeta: metav1.ObjectMeta{
- Name: name,
- Namespace: "ns1",
- Labels: map[string]string{"app": "test"},
- },
- Status: corev1.PodStatus{
- PodIP: "192.168.1.1",
+func newMockPodMetrics(name string, metrics *backendmetrics.MetricsState) *backendmetrics.FakePodMetrics {
+ return &backendmetrics.FakePodMetrics{
+ Pod: &backend.Pod{
+ NamespacedName: types.NamespacedName{Name: name, Namespace: "ns1"},
},
+ Metrics: metrics,
}
-
- // Use the proper constructor
- fakePodMetrics := backendmetrics.NewFakePodMetrics(k8sPod)
-
- // Create a custom fake that can return the specified metrics
- return &testPodMetrics{
- FakePodMetrics: fakePodMetrics,
- customMetrics: metrics,
- }
-}
-
-// testPodMetrics wraps FakePodMetrics to allow custom metrics for testing
-type testPodMetrics struct {
- *backendmetrics.FakePodMetrics
- customMetrics *backendmetrics.MetricsState
-}
-
-// AddRequest implements metrics.PodMetrics.
-// Subtle: this method shadows the method (*FakePodMetrics).AddRequest of testPodMetrics.FakePodMetrics.
-func (t *testPodMetrics) AddRequest(requestID string, tpot float64) bool {
- panic("unimplemented")
-}
-
-// ContainsRequest implements metrics.PodMetrics.
-// Subtle: this method shadows the method (*FakePodMetrics).ContainsRequest of testPodMetrics.FakePodMetrics.
-func (t *testPodMetrics) ContainsRequest(requestID string) bool {
- panic("unimplemented")
-}
-
-// GetPod implements metrics.PodMetrics.
-// Subtle: this method shadows the method (*FakePodMetrics).GetPod of testPodMetrics.FakePodMetrics.
-func (t *testPodMetrics) GetPod() *backend.Pod {
- return t.FakePodMetrics.GetPod()
-}
-
-// GetRequestCount implements metrics.PodMetrics.
-// Subtle: this method shadows the method (*FakePodMetrics).GetRequestCount of testPodMetrics.FakePodMetrics.
-func (t *testPodMetrics) GetRequestCount() int {
- panic("unimplemented")
-}
-
-// GetRunningRequests implements metrics.PodMetrics.
-// Subtle: this method shadows the method (*FakePodMetrics).GetRunningRequests of testPodMetrics.FakePodMetrics.
-func (t *testPodMetrics) GetRunningRequests() *datalayer.RequestPriorityQueue {
- panic("unimplemented")
-}
-
-// PeekRequestPriorityQueue implements metrics.PodMetrics.
-func (t *testPodMetrics) PeekRequestPriorityQueue() *datalayer.Request {
- panic("unimplemented")
-}
-
-// RemoveRequest implements metrics.PodMetrics.
-// Subtle: this method shadows the method (*FakePodMetrics).RemoveRequest of testPodMetrics.FakePodMetrics.
-func (t *testPodMetrics) RemoveRequest(requestID string) bool {
- panic("unimplemented")
-}
-
-// StopRefreshLoop implements metrics.PodMetrics.
-// Subtle: this method shadows the method (*FakePodMetrics).StopRefreshLoop of testPodMetrics.FakePodMetrics.
-func (t *testPodMetrics) StopRefreshLoop() {
- panic("unimplemented")
-}
-
-// String implements metrics.PodMetrics.
-// Subtle: this method shadows the method (*FakePodMetrics).String of testPodMetrics.FakePodMetrics.
-func (t *testPodMetrics) String() string {
- panic("unimplemented")
-}
-
-// UpdatePod implements metrics.PodMetrics.
-// Subtle: this method shadows the method (*FakePodMetrics).UpdatePod of testPodMetrics.FakePodMetrics.
-func (t *testPodMetrics) UpdatePod(*corev1.Pod) {
- panic("unimplemented")
-}
-
-// UpdateRequest implements metrics.PodMetrics.
-// Subtle: this method shadows the method (*FakePodMetrics).UpdateRequest of testPodMetrics.FakePodMetrics.
-func (t *testPodMetrics) UpdateRequest(requestID string, tpot float64) bool {
- panic("unimplemented")
-}
-
-// Override GetMetrics to return custom metrics for testing
-func (t *testPodMetrics) GetMetrics() *backendmetrics.MetricsState {
- return t.customMetrics // Return exactly what was passed, including nil
}
// --- Tests ---
@@ -228,16 +114,16 @@ func TestDetector_IsSaturated(t *testing.T) {
}
tests := []struct {
- name string
- config *Config
- pods []backendmetrics.PodMetrics
- expectedSaturat bool
+ name string
+ config *Config
+ pods []backendmetrics.PodMetrics
+ expectedSaturation bool
}{
{
- name: "No pods in datastore",
- config: defaultConfig,
- pods: []backendmetrics.PodMetrics{},
- expectedSaturat: true, // No capacity = saturated
+ name: "No candidate pods",
+ config: defaultConfig,
+ pods: []backendmetrics.PodMetrics{},
+ expectedSaturation: true, // No capacity = saturated
},
{
name: "Single pod with good capacity",
@@ -247,11 +133,9 @@ func TestDetector_IsSaturated(t *testing.T) {
UpdateTime: baseTime,
WaitingQueueSize: 2,
KVCacheUsagePercent: 0.5,
- ActiveModels: make(map[string]int),
- WaitingModels: make(map[string]int),
}),
},
- expectedSaturat: false,
+ expectedSaturation: false,
},
{
name: "Single pod with stale metrics",
@@ -261,11 +145,9 @@ func TestDetector_IsSaturated(t *testing.T) {
UpdateTime: baseTime.Add(-200 * time.Millisecond), // Stale
WaitingQueueSize: 1,
KVCacheUsagePercent: 0.1,
- ActiveModels: make(map[string]int),
- WaitingModels: make(map[string]int),
}),
},
- expectedSaturat: true,
+ expectedSaturation: true,
},
{
name: "Single pod with high queue depth",
@@ -275,11 +157,9 @@ func TestDetector_IsSaturated(t *testing.T) {
UpdateTime: baseTime,
WaitingQueueSize: 10, // Exceeds threshold 5
KVCacheUsagePercent: 0.1,
- ActiveModels: make(map[string]int),
- WaitingModels: make(map[string]int),
}),
},
- expectedSaturat: true,
+ expectedSaturation: true,
},
{
name: "Single pod with high KV cache utilization",
@@ -289,11 +169,9 @@ func TestDetector_IsSaturated(t *testing.T) {
UpdateTime: baseTime,
WaitingQueueSize: 1,
KVCacheUsagePercent: 0.95, // Exceeds threshold 0.90
- ActiveModels: make(map[string]int),
- WaitingModels: make(map[string]int),
}),
},
- expectedSaturat: true,
+ expectedSaturation: true,
},
{
name: "Single pod with nil metrics",
@@ -301,7 +179,7 @@ func TestDetector_IsSaturated(t *testing.T) {
pods: []backendmetrics.PodMetrics{
newMockPodMetrics("pod1", nil),
},
- expectedSaturat: true,
+ expectedSaturation: true,
},
{
name: "Multiple pods, all good capacity",
@@ -311,18 +189,14 @@ func TestDetector_IsSaturated(t *testing.T) {
UpdateTime: baseTime,
WaitingQueueSize: 1,
KVCacheUsagePercent: 0.1,
- ActiveModels: make(map[string]int),
- WaitingModels: make(map[string]int),
}),
newMockPodMetrics("pod2", &backendmetrics.MetricsState{
UpdateTime: baseTime.Add(-10 * time.Millisecond),
WaitingQueueSize: 0,
KVCacheUsagePercent: 0.2,
- ActiveModels: make(map[string]int),
- WaitingModels: make(map[string]int),
}),
},
- expectedSaturat: false,
+ expectedSaturation: false,
},
{
name: "Multiple pods, one good, one bad (stale)",
@@ -332,18 +206,14 @@ func TestDetector_IsSaturated(t *testing.T) {
UpdateTime: baseTime, // Good
WaitingQueueSize: 1,
KVCacheUsagePercent: 0.1,
- ActiveModels: make(map[string]int),
- WaitingModels: make(map[string]int),
}),
newMockPodMetrics("pod2", &backendmetrics.MetricsState{
UpdateTime: baseTime.Add(-300 * time.Millisecond), // Stale
WaitingQueueSize: 0,
KVCacheUsagePercent: 0.2,
- ActiveModels: make(map[string]int),
- WaitingModels: make(map[string]int),
}),
},
- expectedSaturat: false, // One good pod is enough
+ expectedSaturation: false, // One good pod is enough
},
{
name: "Multiple pods, one good, one bad (high queue)",
@@ -353,18 +223,14 @@ func TestDetector_IsSaturated(t *testing.T) {
UpdateTime: baseTime,
WaitingQueueSize: 1,
KVCacheUsagePercent: 0.1,
- ActiveModels: make(map[string]int),
- WaitingModels: make(map[string]int),
}),
newMockPodMetrics("pod2", &backendmetrics.MetricsState{
UpdateTime: baseTime,
WaitingQueueSize: 15, // Bad queue
KVCacheUsagePercent: 0.2,
- ActiveModels: make(map[string]int),
- WaitingModels: make(map[string]int),
}),
},
- expectedSaturat: false,
+ expectedSaturation: false,
},
{
name: "Multiple pods, all bad capacity",
@@ -374,25 +240,19 @@ func TestDetector_IsSaturated(t *testing.T) {
UpdateTime: baseTime.Add(-200 * time.Millisecond), // Stale
WaitingQueueSize: 1,
KVCacheUsagePercent: 0.1,
- ActiveModels: make(map[string]int),
- WaitingModels: make(map[string]int),
}),
newMockPodMetrics("pod2", &backendmetrics.MetricsState{
UpdateTime: baseTime,
WaitingQueueSize: 20, // High queue
KVCacheUsagePercent: 0.2,
- ActiveModels: make(map[string]int),
- WaitingModels: make(map[string]int),
}),
newMockPodMetrics("pod3", &backendmetrics.MetricsState{
UpdateTime: baseTime,
WaitingQueueSize: 1,
KVCacheUsagePercent: 0.99, // High KV
- ActiveModels: make(map[string]int),
- WaitingModels: make(map[string]int),
}),
},
- expectedSaturat: true,
+ expectedSaturation: true,
},
{
name: "Queue depth exactly at threshold",
@@ -402,11 +262,9 @@ func TestDetector_IsSaturated(t *testing.T) {
UpdateTime: baseTime,
WaitingQueueSize: defaultConfig.QueueDepthThreshold, // Exactly at threshold (good)
KVCacheUsagePercent: 0.1,
- ActiveModels: make(map[string]int),
- WaitingModels: make(map[string]int),
}),
},
- expectedSaturat: false,
+ expectedSaturation: false,
},
{
name: "KV cache exactly at threshold",
@@ -416,11 +274,9 @@ func TestDetector_IsSaturated(t *testing.T) {
UpdateTime: baseTime,
WaitingQueueSize: 1,
KVCacheUsagePercent: defaultConfig.KVCacheUtilThreshold, // Exactly at threshold (good)
- ActiveModels: make(map[string]int),
- WaitingModels: make(map[string]int),
}),
},
- expectedSaturat: false,
+ expectedSaturation: false,
},
{
name: "Metrics age just over staleness threshold",
@@ -430,11 +286,9 @@ func TestDetector_IsSaturated(t *testing.T) {
UpdateTime: baseTime.Add(-defaultConfig.MetricsStalenessThreshold - time.Nanosecond), // Just over (stale)
WaitingQueueSize: 1,
KVCacheUsagePercent: 0.1,
- ActiveModels: make(map[string]int),
- WaitingModels: make(map[string]int),
}),
},
- expectedSaturat: true,
+ expectedSaturation: true,
},
}
@@ -442,8 +296,8 @@ func TestDetector_IsSaturated(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
detector := NewDetector(test.config, logr.Discard())
- if got := detector.IsSaturated(context.Background(), test.pods); got != test.expectedSaturat {
- t.Errorf("IsSaturated() = %v, want %v", got, test.expectedSaturat)
+ if got := detector.IsSaturated(context.Background(), test.pods); got != test.expectedSaturation {
+ t.Errorf("IsSaturated() = %v, want %v", got, test.expectedSaturation)
}
})
}
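
The cases above pin down the capacity rule the detector is expected to apply to each candidate pod: metrics must be fresh, and both waiting-queue depth and KV-cache utilization must be at or below their thresholds; the pool counts as saturated only when no pod qualifies. A rough standalone sketch of that per-pod check under those assumptions (field names taken from the tests, threshold values illustrative, not the detector's actual implementation):

package main

import (
	"fmt"
	"time"
)

// Minimal stand-ins for the metrics/config fields the tests exercise.
type metricsState struct {
	UpdateTime          time.Time
	WaitingQueueSize    int
	KVCacheUsagePercent float64
}

type config struct {
	QueueDepthThreshold       int
	KVCacheUtilThreshold      float64
	MetricsStalenessThreshold time.Duration
}

// hasCapacity mirrors the rule implied by the tests: fresh metrics, and both
// signals at or below their thresholds (values exactly at a threshold are good).
func hasCapacity(m *metricsState, cfg config, now time.Time) bool {
	if m == nil {
		return false
	}
	if now.Sub(m.UpdateTime) > cfg.MetricsStalenessThreshold {
		return false // stale metrics count as no capacity
	}
	return m.WaitingQueueSize <= cfg.QueueDepthThreshold &&
		m.KVCacheUsagePercent <= cfg.KVCacheUtilThreshold
}

func main() {
	cfg := config{QueueDepthThreshold: 5, KVCacheUtilThreshold: 0.9, MetricsStalenessThreshold: 50 * time.Millisecond}
	now := time.Now()
	fresh := &metricsState{UpdateTime: now, WaitingQueueSize: 5, KVCacheUsagePercent: 0.9}
	stale := &metricsState{UpdateTime: now.Add(-time.Second), WaitingQueueSize: 0, KVCacheUsagePercent: 0.1}
	fmt.Println(hasCapacity(fresh, cfg, now), hasCapacity(stale, cfg, now)) // true false
}
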
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go
index bd9e2c96e..8b68132dc 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go
@@ -149,3 +149,35 @@ func (i *indexer) reportLRUSize(ctx context.Context, interval time.Duration) {
i.mu.RUnlock()
}
}
+
+// RemovePod removes a pod and its associated entries from the indexer.
+func (i *indexer) RemovePod(pod ServerID) {
+ i.mu.RLock()
+ lruCache, exists := i.podToLRU[pod]
+ i.mu.RUnlock()
+
+ if !exists {
+ return
+ }
+
+ // Remove all hashes associated with the pod from hashToPods (triggers eviction callbacks).
+ for _, hash := range lruCache.Keys() {
+ lruCache.Remove(hash)
+ }
+
+ i.mu.Lock()
+ delete(i.podToLRU, pod)
+ i.mu.Unlock()
+}
+
+// Pods returns the list of all pods currently tracked in the indexer.
+func (i *indexer) Pods() []ServerID {
+ i.mu.RLock()
+ defer i.mu.RUnlock()
+
+ pods := make([]ServerID, 0, len(i.podToLRU))
+ for pod := range i.podToLRU {
+ pods = append(pods, pod)
+ }
+ return pods
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go
index 6d4fcc5f4..c35af8e27 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go
@@ -46,3 +46,63 @@ func TestIndexer_AddAndGet(t *testing.T) {
servers = i.Get(BlockHash(4))
assert.Empty(t, servers, "Cache should not contain non-existent hash")
}
+
+func TestIndexer_RemovePodAndEviction(t *testing.T) {
+ const indexerSize = 10
+
+ i := newIndexer(context.Background(), indexerSize)
+
+ server1 := ServerID{Namespace: "default", Name: "server1"}
+ server2 := ServerID{Namespace: "default", Name: "server2"}
+
+ // Add indexerSize hashes to both servers
+ var hashes []BlockHash
+ for j := 0; j < indexerSize; j++ {
+ h := BlockHash(j)
+ hashes = append(hashes, h)
+ i.Add([]BlockHash{h}, server1)
+ i.Add([]BlockHash{h}, server2)
+ }
+
+ // Ensure all entries are added
+ assert.Equal(t, indexerSize, i.podToLRU[server1].Len(), "server1 should have 10 entries")
+ assert.Equal(t, indexerSize, i.podToLRU[server2].Len(), "server2 should have 10 entries")
+
+ // Ensure each hash in hashToPods maps to both server1 and server2
+ for _, h := range hashes {
+ pods := i.hashToPods[h]
+ assert.Len(t, pods, 2, "Each hash should be associated with exactly 2 pods")
+ assert.Contains(t, pods, server1, "hash should be associated with server1")
+ assert.Contains(t, pods, server2, "hash should be associated with server2")
+ }
+
+	// Add one more hash (BlockHash(indexerSize)) to server1 → should evict BlockHash(0)
+ evictedHash := BlockHash(0)
+ newHash := BlockHash(indexerSize)
+ i.Add([]BlockHash{newHash}, server1)
+
+ // server1 LRU should still be at max capacity
+ assert.Equal(t, indexerSize, i.podToLRU[server1].Len(), "server1 LRU should maintain max size")
+
+ // BlockHash(0) should no longer have server1 in hashToPods
+ pods := i.Get(evictedHash)
+ assert.NotContains(t, pods, server1, "server1 should be evicted from hashToPods for hash 0")
+ assert.Contains(t, pods, server2, "server2 should still have hash 0")
+
+ // Remove server2
+ i.RemovePod(server2)
+
+ // hashToPods for hash 0 should now be empty
+ pods = i.Get(evictedHash)
+ assert.NotContains(t, pods, server2, "server2 should be removed from hash 0")
+ assert.Empty(t, pods, "hash 0 should have no pods after both eviction and removal")
+
+ // All remaining hashes should map only to server1
+ for hash, pods := range i.hashToPods {
+ assert.Len(t, pods, 1, "hash %v should have only 1 pod after server2 removal", hash)
+ assert.Contains(t, pods, server1, "hash %v should only contain server1", hash)
+ }
+
+ // Ensure hashToPods contains exactly indexerSize hashes (post-eviction and server2 removal)
+ assert.Len(t, i.hashToPods, indexerSize, "hashToPods should contain %d hashes after cleanup", indexerSize)
+}
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
index c87f8e8bf..d0986e25c 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go
@@ -22,11 +22,13 @@ import (
"encoding/json"
"fmt"
"sync"
+ "time"
"github.com/cespare/xxhash/v2"
k8stypes "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/log"
+ backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
@@ -37,7 +39,7 @@ import (
const (
// vLLM default token block size is 16, and a good guess of average characters per token is 4.
- DefaultHashBlockSize = 64
+ DefaultBlockSize = 64
// The maximum number of blocks to match. Two long requests with the same prefix up to this
// limit will be indistinguishable.
// This parameter provides a trade-off between cache size, prefix matching speed and matching
@@ -57,16 +59,23 @@ const (
PrefixCachePluginType = "prefix-cache-scorer"
)
+const (
+ PodActiveCheckInterval = 2 * time.Minute
+
+	// An estimated average number of characters per token, used because the cached request text is not tokenized.
+ averageCharactersPerToken = 4
+)
+
var DefaultConfig = Config{
- HashBlockSize: DefaultHashBlockSize,
+ DefaultBlockSize: DefaultBlockSize,
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
}
type Config struct {
- // The input prompt is broken into sizes of HashBlockSize to calculate block hashes . Requests
+	// The input prompt is broken into blocks of DefaultBlockSize to calculate block hashes. Requests
// with length shorter than the block size will be ignored.
- HashBlockSize int `json:"hashBlockSize"`
+ DefaultBlockSize int `json:"blockSize"`
// MaxPrefixBlocksToMatch is the maximum number of prefix blocks to match. Input beyond this limit will
// be ignored.
MaxPrefixBlocksToMatch int `json:"maxPrefixBlocksToMatch"`
@@ -93,6 +102,8 @@ type podSet map[ServerID]struct{}
type Indexer interface {
Get(hash BlockHash) podSet
Add(hashes []BlockHash, server ServerID)
+ RemovePod(server ServerID)
+ Pods() []ServerID
}
// BlockHash is a hash of the block of request body.
@@ -130,13 +141,15 @@ func (s *SchedulingContextState) Clone() plugins.StateData {
}
// compile-time type assertion
-var _ framework.Scorer = &Plugin{}
-var _ requestcontrol.PreRequest = &Plugin{}
+var (
+ _ framework.Scorer = &Plugin{}
+ _ requestcontrol.PreRequest = &Plugin{}
+)
// PrefixCachePluginFactory defines the factory function for Prefix plugin.
func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
parameters := Config{
- HashBlockSize: DefaultHashBlockSize,
+ DefaultBlockSize: DefaultBlockSize,
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
}
@@ -147,7 +160,9 @@ func PrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle
}
}
- return New(handle.Context(), parameters).WithName(name), nil
+ p := New(handle.Context(), parameters).WithName(name)
+ go p.CleanUpInactivePods(handle.Context(), handle)
+ return p, nil
}
// New initializes a new prefix Plugin and returns its pointer.
@@ -182,9 +197,8 @@ func (p *Plugin) WithName(name string) *Plugin {
// Score returns the scoring result for the given list of pods based on context.
func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
-
// pre score step, hashing prompt and find longest prefix match.
- hashes := hashPrompt(ctx, request, p.config.HashBlockSize, p.config.MaxPrefixBlocksToMatch)
+ hashes := hashPrompt(ctx, request, getBlockSize(pods, p.config.DefaultBlockSize), p.config.MaxPrefixBlocksToMatch)
state := &SchedulingContextState{
PrefixHashes: hashes,
PrefixCacheServers: p.matchLongestPrefix(ctx, hashes),
@@ -212,7 +226,7 @@ func (p *Plugin) Score(ctx context.Context, cycleState *types.CycleState, reques
}
// PreRequest records in the plugin cache the result of the scheduling selection.
-func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult, _ int) {
+func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) {
primaryProfileResult := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
targetPod := primaryProfileResult.TargetPods[0].GetPod() // get the first pod of the primary profile
@@ -235,7 +249,9 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
total := len(state.PrefixHashes)
matchLen := state.PrefixCacheServers[ServerID(targetPod.NamespacedName)]
- metrics.RecordPrefixCacheMatch(matchLen*p.config.HashBlockSize, total*p.config.HashBlockSize)
+
+ blockSize := getBlockSize(primaryProfileResult.TargetPods, p.config.DefaultBlockSize)
+ metrics.RecordPrefixCacheMatch(matchLen*blockSize, total*blockSize)
}
// matchLongestPrefix returns a map of servers and length of prefix that each server caches.
@@ -254,45 +270,81 @@ func (p *Plugin) matchLongestPrefix(ctx context.Context, hashes []BlockHash) map
for server := range cachedServers {
// Update servers with their longest prefix match.
res[server]++
-
}
}
}
return res
}
+// CleanUpInactivePods periodically removes pods that are no longer active from the indexer; it is intended to be run as a goroutine.
+func (m *Plugin) CleanUpInactivePods(ctx context.Context, handle plugins.Handle) {
+ logger := log.FromContext(ctx).V(logutil.VERBOSE)
+ ticker := time.NewTicker(PodActiveCheckInterval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-ticker.C:
+ activePodMetrics := handle.PodList(func(_ backendmetrics.PodMetrics) bool { return true })
+ activePods := make(map[ServerID]struct{}, len(activePodMetrics))
+ for _, pm := range activePodMetrics {
+ activePods[ServerID(pm.GetPod().NamespacedName)] = struct{}{}
+ }
+
+ for _, pod := range m.indexer.Pods() {
+ if _, ok := activePods[pod]; !ok {
+ m.indexer.RemovePod(pod)
+ logger.Info("Removed pod not in active set", "pod", pod)
+ }
+ }
+ }
+ }
+}
+
// hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
-// hash(0) is the hash of the model name, since different models generally don't share prefix cache.
+// hash(0) is calculated including the model name and cache_salt (if provided), since different models generally don't share prefix cache.
// For block i, hash(i) = hash(block i content, hash(i-1)).
func hashPrompt(ctx context.Context, request *types.LLMRequest, cacheBlockSize int, maxPrefixBlocks int) []BlockHash {
loggerDebug := log.FromContext(ctx).V(logutil.DEBUG)
- prompt := []byte(request.Prompt)
- if len(prompt) < cacheBlockSize {
- loggerDebug.Info("Request body too small for prefix cache", "size", len(prompt), "block size", cacheBlockSize)
+ if request == nil || request.Body == nil {
+ loggerDebug.Info("Request or request data is nil, skipping hashing")
return nil
}
- if len(prompt) > cacheBlockSize*maxPrefixBlocks {
- loggerDebug.Info("Truncating input", "size", len(prompt), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
- prompt = prompt[:maxPrefixBlocks*cacheBlockSize]
+
+ userInput, err := getUserInputBytes(request)
+ if err != nil {
+ loggerDebug.Error(err, "Failed to get user input bytes")
+ return nil
}
- // Split the body into blocks of size cacheBlockSize. The +1 is to account for the model.
+
+ if len(userInput) < cacheBlockSize {
+ loggerDebug.Info("Request body too small for prefix cache", "size", len(userInput), "block size", cacheBlockSize)
+ return nil
+ }
+ if len(userInput) > cacheBlockSize*maxPrefixBlocks {
+ loggerDebug.Info("Truncating input", "size", len(userInput), "max prefix blocks", maxPrefixBlocks, "block size", cacheBlockSize)
+ userInput = userInput[:maxPrefixBlocks*cacheBlockSize]
+ }
+ // Split the body into blocks of size cacheBlockSize.
// If the last block is smaller than cacheBlockSize, it will be ignored.
- res := make([]BlockHash, 0, 1+len(prompt)/cacheBlockSize)
+ res := make([]BlockHash, 0, len(userInput)/cacheBlockSize)
// Add the model to the first block hash so that different models have different hashes even with the same body.
-
- firstBlockSize := cacheBlockSize
- if len(prompt) < cacheBlockSize {
- firstBlockSize = len(prompt)
+ h := xxhash.New()
+ _, _ = h.Write([]byte(request.TargetModel))
+ if cacheSalt := request.Body.CacheSalt(); cacheSalt != "" {
+ _, _ = h.Write([]byte(cacheSalt))
}
- firstBlock := prompt[0:firstBlockSize]
- firstBlockWithModel := append([]byte(request.TargetModel), firstBlock...)
- res = append(res, BlockHash(xxhash.Sum64(firstBlockWithModel)))
-
- for i := cacheBlockSize; i+cacheBlockSize <= len(prompt); i += cacheBlockSize {
- block := prompt[i : i+cacheBlockSize]
- prevBlockHash := res[len(res)-1]
- block = append(block, toBytes(prevBlockHash)...)
- res = append(res, BlockHash(xxhash.Sum64(block)))
+
+ prevBlockHash := BlockHash(h.Sum64())
+ for i := 0; i+cacheBlockSize <= len(userInput); i += cacheBlockSize {
+ h.Reset()
+ _, _ = h.Write(userInput[i : i+cacheBlockSize])
+ _, _ = h.Write(toBytes(prevBlockHash))
+ res = append(res, BlockHash(h.Sum64()))
+
+ prevBlockHash = res[len(res)-1]
}
return res
}
@@ -302,3 +354,28 @@ func toBytes(i BlockHash) []byte {
binary.LittleEndian.PutUint64(bytes, uint64(i))
return bytes
}
+
+func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
+ if request.Body.Completions != nil { // assumed to be valid if not nil
+ return []byte(request.Body.Completions.Prompt), nil
+ }
+
+ // must be chat-completions request at this point, return bytes of entire messages
+ return json.Marshal(request.Body.ChatCompletions.Messages)
+}
+
+func getBlockSize(pods []types.Pod, defaultBlockSize int) int {
+ if len(pods) == 0 {
+ return defaultBlockSize
+ }
+
+	// Since all pods originate from the same inference pool, they are assumed to have identical configurations.
+	// Therefore, using the CacheBlockSize value from the first pod suffices.
+ if pod := pods[0]; pod.GetMetrics() != nil {
+ cacheBlockSize := pod.GetMetrics().CacheBlockSize
+ if cacheBlockSize > 0 {
+ return cacheBlockSize * averageCharactersPerToken
+ }
+ }
+ return defaultBlockSize
+}
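
To make the chained hashing above concrete: the seed hash covers the target model (and cache_salt when present), and each block's hash folds in the previous block's hash, so two requests share hash i only if they share the model/salt and every block up to i. A standalone sketch of the chain (illustrative helper, not the plugin's code):

package main

import (
	"encoding/binary"
	"fmt"

	"github.com/cespare/xxhash/v2"
)

func chainHashes(model, salt string, input []byte, blockSize int) []uint64 {
	h := xxhash.New()
	_, _ = h.Write([]byte(model))
	_, _ = h.Write([]byte(salt))
	prev := h.Sum64() // seed: model (+ salt), so different models never share hashes

	var out []uint64
	buf := make([]byte, 8)
	for i := 0; i+blockSize <= len(input); i += blockSize {
		h.Reset()
		_, _ = h.Write(input[i : i+blockSize])
		binary.LittleEndian.PutUint64(buf, prev)
		_, _ = h.Write(buf) // chain in the previous hash
		prev = h.Sum64()
		out = append(out, prev)
	}
	return out
}

func main() {
	a := chainHashes("m", "", []byte("aaaabbbb"), 4)
	b := chainHashes("m", "", []byte("aaaacccc"), 4)
	fmt.Println(a[0] == b[0], a[1] == b[1]) // true false: only the first block is shared
}

getBlockSize above then sizes these blocks in characters by multiplying the pod-reported CacheBlockSize (in tokens) by averageCharactersPerToken, falling back to DefaultBlockSize when no pod metrics are available.
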
diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
index 3fbac2ce1..59a09db52 100644
--- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin_test.go
@@ -19,7 +19,6 @@ package prefix
import (
"context"
"fmt"
- "math"
"math/rand"
"strings"
"testing"
@@ -29,28 +28,32 @@ import (
k8stypes "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
+ backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)
-func TestPrefixPlugin(t *testing.T) {
-
+func TestPrefixPluginCompletion(t *testing.T) {
config := Config{
- HashBlockSize: 4,
+ DefaultBlockSize: 4,
MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
}
plugin := New(context.Background(), config)
- pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
- pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}
+ pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: backendmetrics.NewMetricsState()}
+ pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: backendmetrics.NewMetricsState()}
pods := []types.Pod{pod1, pod2}
// First request.
req1 := &types.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "test-model1",
- Prompt: "aaaaaa",
+ Body: &types.LLMRequestBody{
+ Completions: &types.CompletionsRequest{
+ Prompt: "aaaaaa",
+ },
+ },
}
scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
@@ -70,7 +73,7 @@ func TestPrefixPlugin(t *testing.T) {
"default": {TargetPods: []types.Pod{pod1}},
},
}
- plugin.PreRequest(context.Background(), req1, schedulingResult, 0)
+ plugin.PreRequest(context.Background(), req1, schedulingResult)
plugin.wg.Wait()
// Second request doesn't share any prefix with first one. It should be added to the cache but
@@ -78,7 +81,11 @@ func TestPrefixPlugin(t *testing.T) {
req2 := &types.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "test-model2",
- Prompt: "bbbbbb",
+ Body: &types.LLMRequestBody{
+ Completions: &types.CompletionsRequest{
+ Prompt: "bbbbbb",
+ },
+ },
}
scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods)
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String()))
@@ -98,14 +105,18 @@ func TestPrefixPlugin(t *testing.T) {
"default": {TargetPods: []types.Pod{pod2}},
},
}
- plugin.PreRequest(context.Background(), req2, schedulingResult, 0)
+ plugin.PreRequest(context.Background(), req2, schedulingResult)
plugin.wg.Wait()
// Third request shares partial prefix with first one.
req3 := &types.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "test-model1",
- Prompt: "aaaabbbb",
+ Body: &types.LLMRequestBody{
+ Completions: &types.CompletionsRequest{
+ Prompt: "aaaabbbb",
+ },
+ },
}
scores = plugin.Score(context.Background(), types.NewCycleState(), req3, pods)
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, plugins.StateKey(plugin.TypedName().String()))
@@ -124,14 +135,18 @@ func TestPrefixPlugin(t *testing.T) {
"default": {TargetPods: []types.Pod{pod1}},
},
}
- plugin.PreRequest(context.Background(), req3, schedulingResult, 0)
+ plugin.PreRequest(context.Background(), req3, schedulingResult)
plugin.wg.Wait()
// 4th request is same as req3 except the model is different, still no match.
req4 := &types.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "test-model-new",
- Prompt: "aaaabbbb",
+ Body: &types.LLMRequestBody{
+ Completions: &types.CompletionsRequest{
+ Prompt: "aaaabbbb",
+ },
+ },
}
scores = plugin.Score(context.Background(), types.NewCycleState(), req4, pods)
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req4.RequestId, plugins.StateKey(plugin.TypedName().String()))
@@ -150,14 +165,18 @@ func TestPrefixPlugin(t *testing.T) {
"default": {TargetPods: []types.Pod{pod1}},
},
}
- plugin.PreRequest(context.Background(), req4, schedulingResult, 0)
+ plugin.PreRequest(context.Background(), req4, schedulingResult)
plugin.wg.Wait()
// 5th request shares partial prefix with 3rd one.
req5 := &types.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "test-model1",
- Prompt: "aaaabbbbcccc",
+ Body: &types.LLMRequestBody{
+ Completions: &types.CompletionsRequest{
+ Prompt: "aaaabbbbcccc",
+ },
+ },
}
scores = plugin.Score(context.Background(), types.NewCycleState(), req5, pods)
state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req5.RequestId, plugins.StateKey(plugin.TypedName().String()))
@@ -176,8 +195,153 @@ func TestPrefixPlugin(t *testing.T) {
"default": {TargetPods: []types.Pod{pod1}},
},
}
- plugin.PreRequest(context.Background(), req5, schedulingResult, 0)
+ plugin.PreRequest(context.Background(), req5, schedulingResult)
+ plugin.wg.Wait()
+}
+
+func TestPrefixPluginChatCompletions(t *testing.T) {
+ config := Config{
+ DefaultBlockSize: 4,
+ MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
+ LRUCapacityPerServer: DefaultLRUCapacityPerServer,
+ }
+ plugin := New(context.Background(), config)
+
+ pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: &backendmetrics.MetricsState{}}
+ pods := []types.Pod{pod1}
+
+ // Test with chat completions request
+ req1 := &types.LLMRequest{
+ RequestId: uuid.NewString(),
+ TargetModel: "test-model1",
+ Body: &types.LLMRequestBody{
+ ChatCompletions: &types.ChatCompletionsRequest{
+ Messages: []types.Message{
+ {Role: "user", Content: types.Content{Raw: "hello world"}},
+ {Role: "assistant", Content: types.Content{Raw: "hi there"}},
+ },
+ },
+ },
+ }
+ scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
+ state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
+ assert.NoError(t, err)
+ t.Logf("Chat completions - Hashes %+v, cached servers: %+v", state.PrefixHashes, state.PrefixCacheServers)
+ // Should have some hashes for the JSON-encoded messages
+ assert.Greater(t, len(state.PrefixHashes), 1, "should have hashes for chat completions")
+ assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers initially")
+ assert.Equal(t, float64(0), scores[pod1], "score for pod1")
+}
+
+func TestPrefixPluginChatCompletionsGrowth(t *testing.T) {
+ config := Config{
+ DefaultBlockSize: 8, // Use larger block size for more predictable JSON marshaling
+ MaxPrefixBlocksToMatch: DefaultMaxPrefixBlocks,
+ LRUCapacityPerServer: DefaultLRUCapacityPerServer,
+ }
+ plugin := New(context.Background(), config)
+
+ pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, MetricsState: &backendmetrics.MetricsState{}}
+ pod2 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, MetricsState: &backendmetrics.MetricsState{}}
+ pods := []types.Pod{pod1, pod2}
+
+ // First request with initial conversation
+ req1 := &types.LLMRequest{
+ RequestId: uuid.NewString(),
+ TargetModel: "test-model1",
+ Body: &types.LLMRequestBody{
+ ChatCompletions: &types.ChatCompletionsRequest{
+ Messages: []types.Message{
+ {Role: "system", Content: types.Content{Raw: "You are a helpful assistant"}},
+ {Role: "user", Content: types.Content{Raw: "Hello, how are you?"}},
+ },
+ },
+ },
+ }
+ scores := plugin.Score(context.Background(), types.NewCycleState(), req1, pods)
+ state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req1.RequestId, plugins.StateKey(plugin.TypedName().String()))
+ assert.NoError(t, err)
+ t.Logf("Initial conversation - Hashes %+v, cached servers: %+v", len(state.PrefixHashes), state.PrefixCacheServers)
+ initialHashCount := len(state.PrefixHashes)
+ assert.Greater(t, initialHashCount, 1, "should have hashes for chat completions")
+ assert.Equal(t, 0, len(state.PrefixCacheServers), "there shouldn't be any cached servers initially")
+ assert.Equal(t, float64(0), scores[pod1], "score for pod1")
+ assert.Equal(t, float64(0), scores[pod2], "score for pod2")
+
+ // Simulate pod1 was picked
+ schedulingResult := &types.SchedulingResult{
+ PrimaryProfileName: "default",
+ ProfileResults: map[string]*types.ProfileRunResult{
+ "default": {TargetPods: []types.Pod{pod1}},
+ },
+ }
+ plugin.PreRequest(context.Background(), req1, schedulingResult)
+ plugin.wg.Wait()
+
+ // Second request adds assistant response and new user message (conversation grows)
+ req2 := &types.LLMRequest{
+ RequestId: uuid.NewString(),
+ TargetModel: "test-model1",
+ Body: &types.LLMRequestBody{
+ ChatCompletions: &types.ChatCompletionsRequest{
+ Messages: []types.Message{
+ {Role: "system", Content: types.Content{Raw: "You are a helpful assistant"}},
+ {Role: "user", Content: types.Content{Raw: "Hello, how are you?"}},
+ {Role: "assistant", Content: types.Content{Raw: "I'm doing well, thank you! How can I help you today?"}},
+ {Role: "user", Content: types.Content{Raw: "Can you explain how prefix caching works?"}},
+ },
+ },
+ },
+ }
+ scores = plugin.Score(context.Background(), types.NewCycleState(), req2, pods)
+ state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req2.RequestId, plugins.StateKey(plugin.TypedName().String()))
+ assert.NoError(t, err)
+ t.Logf("Extended conversation - Hashes %+v, cached servers: %+v", len(state.PrefixHashes), state.PrefixCacheServers)
+ extendedHashCount := len(state.PrefixHashes)
+ assert.Greater(t, extendedHashCount, initialHashCount, "extended conversation should have more hashes")
+ assert.Greater(t, len(state.PrefixCacheServers), 0, "should have cached servers from prefix match")
+
+ // Calculate expected score - pod1 should have cached the initial prefix
+ cachedBlocks := state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)]
+ expectedScore := float64(cachedBlocks) / float64(extendedHashCount)
+ assert.Equal(t, expectedScore, scores[pod1], "pod1 should have prefix cache hit")
+ assert.Equal(t, float64(0), scores[pod2], "pod2 should have no cache hit")
+
+ // Simulate pod1 was picked again
+ plugin.PreRequest(context.Background(), req2, schedulingResult)
plugin.wg.Wait()
+
+ // Third request continues the conversation even further
+ req3 := &types.LLMRequest{
+ RequestId: uuid.NewString(),
+ TargetModel: "test-model1",
+ Body: &types.LLMRequestBody{
+ ChatCompletions: &types.ChatCompletionsRequest{
+ Messages: []types.Message{
+ {Role: "system", Content: types.Content{Raw: "You are a helpful assistant"}},
+ {Role: "user", Content: types.Content{Raw: "Hello, how are you?"}},
+ {Role: "assistant", Content: types.Content{Raw: "I'm doing well, thank you! How can I help you today?"}},
+ {Role: "user", Content: types.Content{Raw: "Can you explain how prefix caching works?"}},
+ {Role: "assistant", Content: types.Content{Raw: "Prefix caching is a technique where..."}},
+ {Role: "user", Content: types.Content{Raw: "That's very helpful, thank you!"}},
+ },
+ },
+ },
+ }
+ scores = plugin.Score(context.Background(), types.NewCycleState(), req3, pods)
+ state, err = plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req3.RequestId, plugins.StateKey(plugin.TypedName().String()))
+ assert.NoError(t, err)
+ t.Logf("Long conversation - Hashes %+v, cached servers: %+v", len(state.PrefixHashes), state.PrefixCacheServers)
+ longHashCount := len(state.PrefixHashes)
+ assert.Greater(t, longHashCount, extendedHashCount, "long conversation should have even more hashes")
+ assert.Greater(t, len(state.PrefixCacheServers), 0, "should have cached servers from prefix match")
+
+ // pod1 should have an even higher cache hit rate now
+ cachedBlocks = state.PrefixCacheServers[ServerID(pod1.GetPod().NamespacedName)]
+ expectedScore = float64(cachedBlocks) / float64(longHashCount)
+ assert.Equal(t, expectedScore, scores[pod1], "pod1 should have higher prefix cache hit")
+ assert.Greater(t, scores[pod1], float64(0.5), "cache hit rate should be substantial for growing conversation")
+ assert.Equal(t, float64(0), scores[pod2], "pod2 should still have no cache hit")
}
// TestPrefixPluginStress is a stress test for the prefix scoring plugin, using prompts of increasing length.
@@ -185,7 +349,7 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
blockSize := 4
maxPrefixBlocks := 50000
config := Config{
- HashBlockSize: blockSize,
+ DefaultBlockSize: blockSize,
MaxPrefixBlocksToMatch: maxPrefixBlocks,
LRUCapacityPerServer: DefaultLRUCapacityPerServer,
}
@@ -193,45 +357,44 @@ func BenchmarkPrefixPluginStress(b *testing.B) {
plugin := New(context.Background(), config)
types.NewCycleState()
var promptLen []int
- for i := 1; i <= 1024; i++ {
+ for i := 1; i <= 1024; {
promptLen = append(promptLen, i)
+ i += 10
}
promptLen = append(promptLen, 2048, 4096, 8192, 10000, 20000, 50000)
- for _, i := range promptLen {
- // Generate increasing-length random prompts
- prompt := randomPrompt(4 + i)
- pod := &types.PodMetrics{
- Pod: &backend.Pod{
- NamespacedName: k8stypes.NamespacedName{
- Name: fmt.Sprintf("random-pod-%d", i),
+ for i, v := range promptLen {
+		b.Run(fmt.Sprintf("prompt_%d_length_%d", i, v), func(b *testing.B) {
+ // Generate increasing-length random prompts
+ prompt := randomPrompt(4 + v)
+ pod := &types.PodMetrics{
+ Pod: &backend.Pod{
+ NamespacedName: k8stypes.NamespacedName{
+ Name: fmt.Sprintf("random-pod-%d", v),
+ },
},
- },
- }
-
- pods := []types.Pod{pod}
- req := &types.LLMRequest{
- RequestId: uuid.NewString(),
- TargetModel: "model-stress",
- Prompt: prompt,
- }
-
- // First cycle: simulate scheduling and insert prefix info into the cache
- plugin.Score(context.Background(), types.NewCycleState(), req, pods)
- schedulingResult := &types.SchedulingResult{
- PrimaryProfileName: "default",
- ProfileResults: map[string]*types.ProfileRunResult{
- "default": {TargetPods: []types.Pod{pod}},
- },
- }
- plugin.PreRequest(context.Background(), req, schedulingResult, 0)
- plugin.wg.Wait()
+ }
+
+ pods := []types.Pod{pod}
+ req := &types.LLMRequest{
+ RequestId: uuid.NewString(),
+ TargetModel: "model-stress",
+ Body: &types.LLMRequestBody{
+ Completions: &types.CompletionsRequest{
+ Prompt: prompt,
+ },
+ },
+ }
+
+			b.ResetTimer()
+			// Benchmark the scoring operation for b.N iterations
+			for n := 0; n < b.N; n++ {
+				scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods)
+				_ = scores // Use the result to prevent optimization
+				plugin.pluginState.Delete(req.RequestId) // Clean up state for next iteration
+			}
+		})
- // Second cycle: validate internal state
- state, err := plugins.ReadPluginStateKey[*SchedulingContextState](plugin.pluginState, req.RequestId, plugins.StateKey(plugin.TypedName().String()))
- assert.NoError(b, err)
- expectedHashes := int(math.Min(float64(maxPrefixBlocks), float64(len(req.Prompt)/blockSize)))
- assert.Equal(b, expectedHashes, len(state.PrefixHashes), "number of hashes is incorrect")
}
}
@@ -244,3 +407,75 @@ func randomPrompt(n int) string {
}
return sb.String()
}
+
+// BenchmarkPrefixPluginChatCompletionsStress is a stress test for chat completions with varying message counts and lengths
+func BenchmarkPrefixPluginChatCompletionsStress(b *testing.B) {
+ blockSize := 8
+ maxPrefixBlocks := 50000
+ config := Config{
+ DefaultBlockSize: blockSize,
+ MaxPrefixBlocksToMatch: maxPrefixBlocks,
+ LRUCapacityPerServer: DefaultLRUCapacityPerServer,
+ }
+ plugin := New(context.Background(), config)
+
+ // Test scenarios: varying number of messages and message lengths
+ scenarios := []struct {
+ messageCount int
+ messageLength int
+ }{
+ {2, 50}, // Short conversation, short messages
+ {2, 500}, // Short conversation, long messages
+ {5, 100}, // Medium conversation, medium messages
+ {10, 200}, // Long conversation, medium messages
+ {20, 100}, // Very long conversation, medium messages
+ {50, 50}, // Very long conversation, short messages
+ {100, 25}, // Extremely long conversation, very short messages
+ }
+
+ for _, scenario := range scenarios {
+ b.Run(fmt.Sprintf("messages_%d_length_%d", scenario.messageCount, scenario.messageLength), func(b *testing.B) {
+ // Generate messages for this scenario
+ messages := make([]types.Message, scenario.messageCount)
+ messages[0] = types.Message{Role: "system", Content: types.Content{Raw: "You are a helpful assistant."}}
+
+ for i := 1; i < scenario.messageCount; i++ {
+ role := "user"
+ if i%2 == 0 {
+ role = "assistant"
+ }
+ content := randomPrompt(scenario.messageLength)
+ messages[i] = types.Message{Role: role, Content: types.Content{Raw: content}}
+ }
+
+ pod := &types.PodMetrics{
+ Pod: &backend.Pod{
+ NamespacedName: k8stypes.NamespacedName{
+ Name: fmt.Sprintf("chat-pod-%d-%d", scenario.messageCount, scenario.messageLength),
+ },
+ },
+ }
+ pods := []types.Pod{pod}
+
+ req := &types.LLMRequest{
+ RequestId: uuid.NewString(),
+ TargetModel: "chat-model-stress",
+ Body: &types.LLMRequestBody{
+ ChatCompletions: &types.ChatCompletionsRequest{
+ Messages: messages,
+ },
+ },
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // Benchmark the scoring operation
+ scores := plugin.Score(context.Background(), types.NewCycleState(), req, pods)
+ _ = scores // Use the result to prevent optimization
+
+ // Clean up state for next iteration
+ plugin.pluginState.Delete(req.RequestId)
+ }
+ })
+ }
+}
diff --git a/pkg/epp/requestcontrol/latencypredictor_helper.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go
similarity index 78%
rename from pkg/epp/requestcontrol/latencypredictor_helper.go
rename to pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go
index 8cd840391..ed86e5aa4 100644
--- a/pkg/epp/requestcontrol/latencypredictor_helper.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/latencypredictor_helper.go
@@ -11,7 +11,7 @@ distributed under the License.
*/
-// Package requestcontrol contains helpers to decouple latency-predictor logic.
+// Package slo_aware_router contains helpers to decouple latency-predictor logic.
-package requestcontrol
+package slo_aware_router
import (
"context"
@@ -24,19 +24,24 @@ import (
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers"
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/latencypredictorasync"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
)
-// RefreshLastSeenMetrics updates reqCtx.LastSeenMetrics from the latest scheduling result.
-func RefreshLastSeenMetrics(ctx context.Context, reqCtx *handlers.RequestContext) {
- if sr := reqCtx.SchedulingResult; sr != nil {
+const (
+ // Poisson sampling parameters for predictions
+ defaultSamplingMean = 100 // Mean interval between prediction samples (tokens)
+ maxSampledTokens = 20 // Maximum number of prediction samples per request
+)
+
+// RefreshLastSeenMetrics updates sloCtx.LastSeenMetrics from the latest scheduling result.
+func RefreshLastSeenMetrics(ctx context.Context, sloCtx *SLORequestContext) {
+ if sr := sloCtx.SchedulingResult; sr != nil {
if pr := sr.ProfileResults[sr.PrimaryProfileName]; pr != nil && pr.TargetPods != nil {
for profileName, profileResult := range sr.ProfileResults {
if profileResult != nil && profileResult.TargetPods != nil && len(profileResult.TargetPods) > 0 {
- reqCtx.LastSeenMetrics[profileName] = profileResult.TargetPods[0].GetMetrics().Clone()
+ sloCtx.LastSeenMetrics[profileName] = profileResult.TargetPods[0].GetMetrics().Clone()
}
}
}
@@ -99,32 +104,32 @@ func GetTargetPodForProfile(
return targetPod
}
-// GetMetricsForPrediction retrieves the latest metrics for prediction from reqCtx.LastSeenMetrics.
-func GetLatestMetricsForProfile(ctx context.Context, reqCtx *handlers.RequestContext, profileName string) (*backendmetrics.MetricsState, error) {
- if len(reqCtx.LastSeenMetrics) == 0 {
+// GetLatestMetricsForProfile retrieves the latest metrics for prediction from sloCtx.LastSeenMetrics.
+func GetLatestMetricsForProfile(ctx context.Context, sloCtx *SLORequestContext, profileName string) (*backendmetrics.MetricsState, error) {
+ if len(sloCtx.LastSeenMetrics) == 0 {
return nil, fmt.Errorf("no last seen metrics available for prediction")
}
// Use the primary profile's metrics for prediction
- if metrics, exists := reqCtx.LastSeenMetrics[profileName]; exists {
+ if metrics, exists := sloCtx.LastSeenMetrics[profileName]; exists {
return metrics, nil
}
log.FromContext(ctx).V(logutil.DEBUG).Info("No metrics found for profile, trying primary profile", "profile_name", profileName)
- primaryProfileName := reqCtx.SchedulingResult.PrimaryProfileName
- if metrics, exists := reqCtx.LastSeenMetrics[primaryProfileName]; exists {
+ primaryProfileName := sloCtx.SchedulingResult.PrimaryProfileName
+ if metrics, exists := sloCtx.LastSeenMetrics[primaryProfileName]; exists {
return metrics, nil
}
return nil, fmt.Errorf("no metrics found for primary profile %s", primaryProfileName)
}
-// ProcessHeader refreshes metrics, applies TTFT prediction, updates reqCtx.PredictedTTFT and timestamp.
+// ProcessHeaderForLatencyPrediction refreshes metrics, applies TTFT prediction, updates sloCtx.PredictedTTFT and timestamp.
func ProcessHeaderForLatencyPrediction(
ctx context.Context,
predictor latencypredictor.PredictorInterface,
- reqCtx *handlers.RequestContext,
+ sloCtx *SLORequestContext,
) error {
logger := log.FromContext(ctx)
@@ -133,18 +138,18 @@ func ProcessHeaderForLatencyPrediction(
// Build prediction request
//check if prefill profile name is set, if not use primary profile name
- m, err := GetLatestMetricsForProfile(ctx, reqCtx, "prefill")
+ m, err := GetLatestMetricsForProfile(ctx, sloCtx, "prefill")
if err != nil {
logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err)
return err
}
- targetPod := GetTargetPodForProfile(ctx, reqCtx.SchedulingResult, "prefill")
- prefix_cache_score := GetPrefixCacheScoreForPod(ctx, reqCtx.SchedulingResult, targetPod, "prefill")
+ targetPod := GetTargetPodForProfile(ctx, sloCtx.SchedulingResult, "prefill")
+ prefix_cache_score := GetPrefixCacheScoreForPod(ctx, sloCtx.SchedulingResult, targetPod, "prefill")
in := latencypredictor.PredictionRequest{
KVCachePercentage: m.KVCacheUsagePercent,
- InputTokenLength: len(strings.Fields(reqCtx.SchedulingRequest.Prompt)),
+ InputTokenLength: len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)),
NumRequestWaiting: m.WaitingQueueSize,
NumRequestRunning: m.RunningQueueSize,
NumTokensGenerated: 0,
@@ -157,55 +162,55 @@ func ProcessHeaderForLatencyPrediction(
dur := time.Since(start)
if err != nil {
logger.V(logutil.DEBUG).Error(err, "header TTFT predict failed", "duration_ms", dur.Milliseconds())
- reqCtx.PredictedTTFT = 0
+ sloCtx.PredictedTTFT = 0
} else if p == nil {
logger.V(logutil.DEBUG).Info("header TTFT predict nil", "duration_ms", dur.Milliseconds())
- reqCtx.PredictedTTFT = 0
+ sloCtx.PredictedTTFT = 0
} else {
logger.V(logutil.DEBUG).Info("header TTFT succeeded", "value_ms", p.TTFT, "duration_ms", dur.Milliseconds())
- metrics.RecordRequestTTFTPredictionDuration(ctx, reqCtx.TargetModelName, reqCtx.IncomingModelName, dur.Seconds())
+ metrics.RecordRequestTTFTPredictionDuration(ctx, sloCtx.SchedulingRequest.TargetModel, sloCtx.IncomingModelName, dur.Seconds())
- reqCtx.PredictedTTFT = p.TTFT
+ sloCtx.PredictedTTFT = p.TTFT
}
// Advance timestamp for first token reference
- reqCtx.LastTokenTimestamp = time.Now()
- RefreshLastSeenMetrics(ctx, reqCtx)
+ sloCtx.LastTokenTimestamp = time.Now()
+ RefreshLastSeenMetrics(ctx, sloCtx)
return err
}
-// ProcessFirstToken records actual TTFT, trains, predicts first TPOT, updates reqCtx, and advances timestamp.
+// ProcessFirstTokenForLatencyPrediction records actual TTFT, trains, predicts first TPOT, updates sloCtx, and advances timestamp.
func ProcessFirstTokenForLatencyPrediction(
ctx context.Context,
predictor latencypredictor.PredictorInterface,
- reqCtx *handlers.RequestContext,
+ sloCtx *SLORequestContext,
now time.Time,
) {
logger := log.FromContext(ctx)
// Initialize sampler
- if reqCtx.TokenSampler == nil {
- requestID := reqCtx.Request.Headers[requtil.RequestIdHeaderKey]
- reqCtx.TokenSampler = requtil.NewTokenSampler(requestID, defaultSamplingMean, maxSampledTokens)
- logger.V(logutil.DEBUG).Info("Initialized token sampler for first token", "request_id", requestID, "next_prediction_token", reqCtx.TokenSampler.GetNextSampleToken())
+ if sloCtx.TokenSampler == nil {
+ requestID := sloCtx.SchedulingRequest.Headers[requtil.RequestIdHeaderKey]
+ sloCtx.TokenSampler = requtil.NewTokenSampler(requestID, defaultSamplingMean, maxSampledTokens)
+ logger.V(logutil.DEBUG).Info("Initialized token sampler for first token", "request_id", requestID, "next_prediction_token", sloCtx.TokenSampler.GetNextSampleToken())
}
// Actual TTFT
- reqCtx.TTFT = float64(now.Sub(reqCtx.RequestReceivedTimestamp).Milliseconds())
- reqCtx.GeneratedTokenCount = 1
- m, err := GetLatestMetricsForProfile(ctx, reqCtx, "prefill")
+ sloCtx.TTFT = float64(now.Sub(sloCtx.RequestReceivedTimestamp).Milliseconds())
+ sloCtx.GeneratedTokenCount = 1
+ m, err := GetLatestMetricsForProfile(ctx, sloCtx, "prefill")
if err != nil {
logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err)
return
}
- targetPod := GetTargetPodForProfile(ctx, reqCtx.SchedulingResult, "prefill")
- prefix_cache_score := GetPrefixCacheScoreForPod(ctx, reqCtx.SchedulingResult, targetPod, "prefill")
+ targetPod := GetTargetPodForProfile(ctx, sloCtx.SchedulingResult, "prefill")
+ prefix_cache_score := GetPrefixCacheScoreForPod(ctx, sloCtx.SchedulingResult, targetPod, "prefill")
// Train TTFT
entry := latencypredictor.TrainingEntry{
KVCachePercentage: m.KVCacheUsagePercent,
- InputTokenLength: len(strings.Fields(reqCtx.SchedulingRequest.Prompt)),
- ActualTTFT: reqCtx.TTFT,
+ InputTokenLength: len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)),
+ ActualTTFT: sloCtx.TTFT,
ActualTPOT: 0,
Timestamp: now,
NumRequestWaiting: m.WaitingQueueSize,
@@ -216,7 +221,7 @@ func ProcessFirstTokenForLatencyPrediction(
if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
logger.V(logutil.DEBUG).Error(err, "record TTFT training failed")
}
- m, err = GetLatestMetricsForProfile(ctx, reqCtx, reqCtx.SchedulingResult.PrimaryProfileName)
+ m, err = GetLatestMetricsForProfile(ctx, sloCtx, sloCtx.SchedulingResult.PrimaryProfileName)
if err != nil {
logger.V(logutil.DEBUG).Info("Skipping first TPOT prediction due to missing metrics",
"error", err)
@@ -226,10 +231,10 @@ func ProcessFirstTokenForLatencyPrediction(
// Predict first TPOT
in := latencypredictor.PredictionRequest{
KVCachePercentage: m.KVCacheUsagePercent,
- InputTokenLength: len(strings.Fields(reqCtx.SchedulingRequest.Prompt)),
+ InputTokenLength: len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)),
NumRequestWaiting: m.WaitingQueueSize,
NumRequestRunning: m.RunningQueueSize,
- NumTokensGenerated: reqCtx.GeneratedTokenCount,
+ NumTokensGenerated: sloCtx.GeneratedTokenCount,
PrefixCacheScore: 0,
}
start := time.Now()
@@ -237,48 +242,48 @@ func ProcessFirstTokenForLatencyPrediction(
dur := time.Since(start)
if err != nil || p == nil {
logger.V(logutil.DEBUG).Error(err, "first TPOT predict failed", "duration_ms", dur.Milliseconds())
- reqCtx.PredictedTPOTObservations = append(reqCtx.PredictedTPOTObservations, 0)
- reqCtx.AvgPredictedTPOT = calculateRunningAverage(reqCtx.AvgPredictedTPOT, 0, len(reqCtx.PredictedTPOTObservations))
+ sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, 0)
+ sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, 0, len(sloCtx.PredictedTPOTObservations))
} else {
logger.V(logutil.DEBUG).Info("first TPOT succeeded", "value_ms", p.TPOT, "duration_ms", dur.Milliseconds())
- reqCtx.PredictedTPOTObservations = append(reqCtx.PredictedTPOTObservations, p.TPOT)
- reqCtx.AvgPredictedTPOT = calculateRunningAverage(reqCtx.AvgPredictedTPOT, p.TPOT, len(reqCtx.PredictedTPOTObservations))
+ sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, p.TPOT)
+ sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, p.TPOT, len(sloCtx.PredictedTPOTObservations))
}
- metrics.RecordRequestTPOTPredictionDuration(ctx, reqCtx.TargetModelName, reqCtx.IncomingModelName, dur.Seconds())
+ metrics.RecordRequestTPOTPredictionDuration(ctx, sloCtx.SchedulingRequest.TargetModel, sloCtx.IncomingModelName, dur.Seconds())
// Advance timestamp
- reqCtx.LastTokenTimestamp = now
+ sloCtx.LastTokenTimestamp = now
// Refresh metrics
- RefreshLastSeenMetrics(ctx, reqCtx)
+ RefreshLastSeenMetrics(ctx, sloCtx)
}
-// ProcessToken records actual inter-token latency, trains, predicts sampled TPOT, updates reqCtx, and advances timestamp.
+// ProcessTokenForLatencyPrediction records the actual inter-token latency, trains, predicts sampled TPOT, updates sloCtx, and advances the timestamp.
func ProcessTokenForLatencyPrediction(
ctx context.Context,
predictor latencypredictor.PredictorInterface,
- reqCtx *handlers.RequestContext,
+ sloCtx *SLORequestContext,
now time.Time,
) {
logger := log.FromContext(ctx)
// Initialize sampler if not yet
- if reqCtx.TokenSampler == nil {
- requestID := reqCtx.Request.Headers[requtil.RequestIdHeaderKey]
- reqCtx.TokenSampler = requtil.NewTokenSampler(requestID, defaultSamplingMean, maxSampledTokens)
- logger.V(logutil.DEBUG).Info("Initialized token sampler for subsequent tokens", "request_id", requestID, "next_prediction_token", reqCtx.TokenSampler.GetNextSampleToken())
+ if sloCtx.TokenSampler == nil {
+ requestID := sloCtx.SchedulingRequest.Headers[requtil.RequestIdHeaderKey]
+ sloCtx.TokenSampler = requtil.NewTokenSampler(requestID, defaultSamplingMean, maxSampledTokens)
+ logger.V(logutil.DEBUG).Info("Initialized token sampler for subsequent tokens", "request_id", requestID, "next_prediction_token", sloCtx.TokenSampler.GetNextSampleToken())
}
// Inter-token latency
- latencyMs := float64(now.Sub(reqCtx.LastTokenTimestamp).Milliseconds())
- reqCtx.GeneratedTokenCount++
+ latencyMs := float64(now.Sub(sloCtx.LastTokenTimestamp).Milliseconds())
+ sloCtx.GeneratedTokenCount++
//log the inter-token latency for predicted samples
- if reqCtx.GeneratedTokenCount == 2 || reqCtx.TokenSampler.ShouldPredict(reqCtx.GeneratedTokenCount) { //tricky logic, since next sample token is always +1 from current token
- reqCtx.TPOTObservations = append(reqCtx.TPOTObservations, latencyMs)
- reqCtx.AvgTPOT = calculateRunningAverage(reqCtx.AvgTPOT, latencyMs, len(reqCtx.TPOTObservations))
+ if sloCtx.GeneratedTokenCount == 2 || sloCtx.TokenSampler.ShouldPredict(sloCtx.GeneratedTokenCount) { //tricky logic, since next sample token is always +1 from current token
+ sloCtx.TPOTObservations = append(sloCtx.TPOTObservations, latencyMs)
+ sloCtx.AvgTPOT = calculateRunningAverage(sloCtx.AvgTPOT, latencyMs, len(sloCtx.TPOTObservations))
}
- m, err := GetLatestMetricsForProfile(ctx, reqCtx, reqCtx.SchedulingResult.PrimaryProfileName)
+ m, err := GetLatestMetricsForProfile(ctx, sloCtx, sloCtx.SchedulingResult.PrimaryProfileName)
if err != nil {
logger.V(logutil.DEBUG).Info("Skipping first TPOT prediction due to missing metrics",
"error", err)
@@ -287,13 +292,13 @@ func ProcessTokenForLatencyPrediction(
// Record actual TPOT
entry := latencypredictor.TrainingEntry{
KVCachePercentage: m.KVCacheUsagePercent,
- InputTokenLength: len(strings.Fields(reqCtx.SchedulingRequest.Prompt)),
+ InputTokenLength: len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)),
ActualTTFT: 0,
ActualTPOT: latencyMs,
Timestamp: now,
NumRequestWaiting: m.WaitingQueueSize,
NumRequestRunning: m.RunningQueueSize,
- NumTokensGenerated: reqCtx.GeneratedTokenCount - 1,
+ NumTokensGenerated: sloCtx.GeneratedTokenCount - 1,
PrefixCacheScore: 0, // TPOT does not use prefix cache score
}
if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
@@ -301,13 +306,13 @@ func ProcessTokenForLatencyPrediction(
}
// Sampled predict
- if reqCtx.TokenSampler.ShouldPredict(reqCtx.GeneratedTokenCount) {
+ if sloCtx.TokenSampler.ShouldPredict(sloCtx.GeneratedTokenCount) {
in := latencypredictor.PredictionRequest{
KVCachePercentage: m.KVCacheUsagePercent,
- InputTokenLength: len(strings.Fields(reqCtx.SchedulingRequest.Prompt)),
+ InputTokenLength: len(strings.Fields(sloCtx.SchedulingRequest.Body.Completions.Prompt)),
NumRequestWaiting: m.WaitingQueueSize,
NumRequestRunning: m.RunningQueueSize,
- NumTokensGenerated: reqCtx.GeneratedTokenCount,
+ NumTokensGenerated: sloCtx.GeneratedTokenCount,
PrefixCacheScore: 0, // TPOT does not use prefix cache score
}
start := time.Now()
@@ -315,22 +320,22 @@ func ProcessTokenForLatencyPrediction(
dur := time.Since(start)
if err != nil || p == nil {
logger.V(logutil.DEBUG).Error(err, "TPOT predict failed", "duration_ms", dur.Milliseconds())
- reqCtx.PredictedTPOTObservations = append(reqCtx.PredictedTPOTObservations, 0)
- reqCtx.AvgPredictedTPOT = calculateRunningAverage(reqCtx.AvgPredictedTPOT, 0, len(reqCtx.PredictedTPOTObservations))
+ sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, 0)
+ sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, 0, len(sloCtx.PredictedTPOTObservations))
} else {
logger.V(logutil.DEBUG).Info("TPOT predict succeeded", "value_ms", p.TPOT, "duration_ms", dur.Milliseconds())
- reqCtx.PredictedTPOTObservations = append(reqCtx.PredictedTPOTObservations, p.TPOT)
- reqCtx.AvgPredictedTPOT = calculateRunningAverage(reqCtx.AvgPredictedTPOT, p.TPOT, len(reqCtx.PredictedTPOTObservations))
+ sloCtx.PredictedTPOTObservations = append(sloCtx.PredictedTPOTObservations, p.TPOT)
+ sloCtx.AvgPredictedTPOT = calculateRunningAverage(sloCtx.AvgPredictedTPOT, p.TPOT, len(sloCtx.PredictedTPOTObservations))
}
- metrics.RecordRequestTPOTPredictionDuration(ctx, reqCtx.TargetModelName, reqCtx.IncomingModelName, dur.Seconds())
+ metrics.RecordRequestTPOTPredictionDuration(ctx, sloCtx.SchedulingRequest.TargetModel, sloCtx.IncomingModelName, dur.Seconds())
- reqCtx.TokenSampler.RecordPrediction(reqCtx.GeneratedTokenCount)
+ sloCtx.TokenSampler.RecordPrediction(sloCtx.GeneratedTokenCount)
}
// Advance timestamp
- reqCtx.LastTokenTimestamp = now
+ sloCtx.LastTokenTimestamp = now
// Refresh metrics
- RefreshLastSeenMetrics(ctx, reqCtx)
+ RefreshLastSeenMetrics(ctx, sloCtx)
}
// PredictWithMetrics predicts TTFT or TPOT based on provided metrics state and token count.
@@ -488,19 +493,19 @@ func BulkPredictWithMetrics(
}
// Fixed DebugPrintRawScores for map[string]map[Pod]float64 structure
-func DebugPrintRawScores(ctx context.Context, reqCtx *handlers.RequestContext) {
+func DebugPrintRawScores(ctx context.Context, sloCtx *SLORequestContext) {
logger := log.FromContext(ctx)
- if reqCtx.SchedulingResult == nil || reqCtx.SchedulingResult.AllProfileRunResults == nil {
+ if sloCtx.SchedulingResult == nil || sloCtx.SchedulingResult.AllProfileRunResults == nil {
logger.V(logutil.DEBUG).Info("No raw scheduling results available for debug")
return
}
logger.V(logutil.DEBUG).Info("=== RAW SCHEDULING RESULTS DEBUG START ===",
- "total_profiles", len(reqCtx.SchedulingResult.AllProfileRunResults))
+ "total_profiles", len(sloCtx.SchedulingResult.AllProfileRunResults))
// Print raw results for all profiles
- for profileName, profileResult := range reqCtx.SchedulingResult.AllProfileRunResults {
+ for profileName, profileResult := range sloCtx.SchedulingResult.AllProfileRunResults {
if profileResult == nil {
logger.V(logutil.DEBUG).Info("Profile result is nil", "profile", profileName)
continue
@@ -652,3 +657,14 @@ func GetPrefixCacheScoreForPod(
"profile", targetProfile)
return 0.0
}
+
+// calculateRunningAverage calculates the running average efficiently
+func calculateRunningAverage(currentAvg float64, newValue float64, count int) float64 {
+ if count == 0 {
+ return 0
+ }
+ if count == 1 {
+ return newValue
+ }
+ return currentAvg + (newValue-currentAvg)/float64(count)
+}
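
For illustration (not part of the diff): the helper above folds each new observation into the mean without keeping the full history, so the result matches recomputing the mean over all samples. A minimal sketch of how it is typically driven:

```go
package main

import "fmt"

// calculateRunningAverage mirrors the helper added above.
func calculateRunningAverage(currentAvg float64, newValue float64, count int) float64 {
	if count == 0 {
		return 0
	}
	if count == 1 {
		return newValue
	}
	return currentAvg + (newValue-currentAvg)/float64(count)
}

func main() {
	observations := []float64{12, 8, 10, 14} // e.g. per-token latencies in ms
	avg := 0.0
	for i, v := range observations {
		avg = calculateRunningAverage(avg, v, i+1)
	}
	fmt.Printf("running average: %.2f ms\n", avg) // 11.00, same as (12+8+10+14)/4
}
```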
diff --git a/pkg/epp/scheduling/framework/plugins/scorer/slo_scorer.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/plugin.go
similarity index 88%
rename from pkg/epp/scheduling/framework/plugins/scorer/slo_scorer.go
rename to pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/plugin.go
index bdf6a1aae..bff535722 100644
--- a/pkg/epp/scheduling/framework/plugins/scorer/slo_scorer.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/plugin.go
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
-package scorer
+package slo_aware_router
import (
"context"
@@ -29,10 +29,8 @@ import (
"sigs.k8s.io/controller-runtime/pkg/log"
"k8s.io/apimachinery/pkg/types"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore"
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/latencypredictorasync"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
@@ -55,9 +53,9 @@ const (
)
const (
- SLOScorerPluginType = "slo-scorer"
- MinScore = 0
- MaxScore = 100
+ SLOAwareRouterPluginType = "slo-aware-routing"
+ MinScore = 0
+ MaxScore = 100
)
var SLOBufferFactor = func() float64 {
@@ -163,50 +161,52 @@ type PodPredictionResult struct {
PrefixCacheScore float64 // Prefix cache score for the pod
}
-type SLOScorer struct {
- tn plugins.TypedName
- predictor latencypredictor.PredictorInterface
- datastore datastore.Datastore
- headroomStrategy HeadroomStrategy
+type SLOAwareRouter struct {
+ tn plugins.TypedName
+ latencypredictor latencypredictor.PredictorInterface
+ runningRequestLists map[types.NamespacedName]*RequestPriorityQueue
+ sloContextStore map[string]*SLORequestContext
+ headroomStrategy HeadroomStrategy
}
-func (s *SLOScorer) Dependencies() []plugins.TypedName {
+func (s *SLOAwareRouter) Dependencies() []plugins.TypedName {
return []plugins.TypedName{
{Type: "prefix-cache-scorer", Name: "prefix-cache-scorer"},
}
}
-var _ framework.Scorer = &SLOScorer{}
+var _ framework.Scorer = &SLOAwareRouter{}
-func NewSLOScorer(predictor latencypredictor.PredictorInterface, datastore datastore.Datastore, strategy HeadroomStrategy) *SLOScorer {
- return &SLOScorer{
- tn: plugins.TypedName{Type: SLOScorerPluginType, Name: SLOScorerPluginType},
- predictor: predictor,
- datastore: datastore,
- headroomStrategy: strategy,
+func NewSLOAwareRouter(latencypredictor latencypredictor.PredictorInterface, strategy HeadroomStrategy) *SLOAwareRouter {
+ return &SLOAwareRouter{
+ tn: plugins.TypedName{Type: SLOAwareRouterPluginType, Name: SLOAwareRouterPluginType},
+ latencypredictor: latencypredictor,
+ runningRequestLists: make(map[types.NamespacedName]*RequestPriorityQueue),
+ sloContextStore: make(map[string]*SLORequestContext),
+ headroomStrategy: strategy,
}
}
-func (s *SLOScorer) TypedName() plugins.TypedName {
+func (s *SLOAwareRouter) TypedName() plugins.TypedName {
return s.tn
}
-func (s *SLOScorer) WithName(name string) *SLOScorer {
+func (s *SLOAwareRouter) WithName(name string) *SLOAwareRouter {
s.tn.Name = name
return s
}
// SetHeadroomStrategy allows runtime configuration of headroom selection strategy
-func (s *SLOScorer) SetHeadroomStrategy(strategy HeadroomStrategy) {
+func (s *SLOAwareRouter) SetHeadroomStrategy(strategy HeadroomStrategy) {
s.headroomStrategy = strategy
}
// GetHeadroomStrategy returns the current headroom selection strategy
-func (s *SLOScorer) GetHeadroomStrategy() HeadroomStrategy {
+func (s *SLOAwareRouter) GetHeadroomStrategy() HeadroomStrategy {
return s.headroomStrategy
}
-func (s *SLOScorer) epsilonGreedyAffinityGate(
+func (s *SLOAwareRouter) epsilonGreedyAffinityGate(
ctx context.Context,
candidates []PodPredictionResult,
r *rand.Rand,
@@ -239,10 +239,10 @@ func (s *SLOScorer) epsilonGreedyAffinityGate(
return eligible, true
}
-func (s *SLOScorer) Score(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) map[schedulingtypes.Pod]float64 {
+func (s *SLOAwareRouter) Score(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, pods []schedulingtypes.Pod) map[schedulingtypes.Pod]float64 {
logger := log.FromContext(ctx)
- if s.predictor == nil {
- logger.V(logutil.DEBUG).Info("SLOScorer: no predictor configured, returning nil scores")
+ if s.latencypredictor == nil {
+ logger.V(logutil.DEBUG).Info("SLOAwareRouter: no predictor configured, returning nil scores")
return nil
}
@@ -345,7 +345,7 @@ func (s *SLOScorer) Score(ctx context.Context, state *schedulingtypes.CycleState
// selectFromPositiveHeadroomPods selects a pod from positive headroom pods using headroom strategy
// Updated to incorporate TTFTHeadroom with a configurable blend vs TPOT headroom.
-func (s *SLOScorer) selectFromPositiveHeadroomPods(ctx context.Context, posHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+func (s *SLOAwareRouter) selectFromPositiveHeadroomPods(ctx context.Context, posHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
logger := log.FromContext(ctx)
if len(posHeadroomPods) == 1 {
@@ -458,6 +458,7 @@ func (s *SLOScorer) selectFromPositiveHeadroomPods(ctx context.Context, posHeadr
// If no pod was selected (shouldn't happen), fallback to first pod
if selectedPod == nil {
-		selectedPod = candidates[0].Pod
+		selectedPod = posHeadroomPods[0].Pod
}
return selectedPod
@@ -465,7 +466,7 @@ func (s *SLOScorer) selectFromPositiveHeadroomPods(ctx context.Context, posHeadr
// selectFromNegativeHeadroomPods selects a pod from negative headroom pods using hierarchical TTFT/TPOT logic
// Modified to strictly prefer pods with 0 running requests
-func (s *SLOScorer) selectFromNegativeHeadroomPods(ctx context.Context, negHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+func (s *SLOAwareRouter) selectFromNegativeHeadroomPods(ctx context.Context, negHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
logger := log.FromContext(ctx)
if len(negHeadroomPods) == 1 {
@@ -500,7 +501,7 @@ func (s *SLOScorer) selectFromNegativeHeadroomPods(ctx context.Context, negHeadr
}
// selectFromNegativeHeadroomPodsInternal handles the actual selection logic for negative headroom pods
-func (s *SLOScorer) selectFromNegativeHeadroomPodsInternal(ctx context.Context, negHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
+func (s *SLOAwareRouter) selectFromNegativeHeadroomPodsInternal(ctx context.Context, negHeadroomPods []PodPredictionResult, r *rand.Rand) schedulingtypes.Pod {
if len(negHeadroomPods) == 1 {
return negHeadroomPods[0].Pod
}
@@ -543,7 +544,7 @@ func (s *SLOScorer) selectFromNegativeHeadroomPodsInternal(ctx context.Context,
// weightPodsByBlendedDeficit applies blended weighting using TTFT and TPOT deficits.
// Lower blended deficit => higher weight.
-func (ps *SLOScorer) weightPodsByBlendedDeficit(
+func (ps *SLOAwareRouter) weightPodsByBlendedDeficit(
ctx context.Context,
pods []PodPredictionResult,
choices *[]Choice,
@@ -643,7 +644,7 @@ func (ps *SLOScorer) weightPodsByBlendedDeficit(
}
}
-func (s *SLOScorer) handleNegativeHeadroomPodsHierarchical(
+func (s *SLOAwareRouter) handleNegativeHeadroomPodsHierarchical(
ctx context.Context,
negHeadroomPods []PodPredictionResult,
choices *[]Choice,
@@ -700,7 +701,7 @@ func (s *SLOScorer) handleNegativeHeadroomPodsHierarchical(
}
// generatePredictions creates prediction results for all candidate pods
-func (s *SLOScorer) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, candidatePods []schedulingtypes.Pod) []PodPredictionResult {
+func (s *SLOAwareRouter) generatePredictions(ctx context.Context, state *schedulingtypes.CycleState, request *schedulingtypes.LLMRequest, candidatePods []schedulingtypes.Pod) []PodPredictionResult {
logger := log.FromContext(ctx)
predictions := make([]PodPredictionResult, 0, len(candidatePods))
@@ -712,10 +713,8 @@ func (s *SLOScorer) generatePredictions(ctx context.Context, state *schedulingty
// Get prefix cache score for the pod
prefixCacheScore := s.getPrefixCacheScoreForPod(ctx, state, pod)
- // TODO update the request in the datastore request tracker
-
// Generate prediction
- prediction, err := requestcontrol.PredictWithMetrics(ctx, s.predictor, pod.GetMetrics(), request.Prompt, 1, prefixCacheScore)
+ prediction, err := PredictWithMetrics(ctx, s.latencypredictor, pod.GetMetrics(), request.Body.Completions.Prompt, 1, prefixCacheScore)
if err != nil {
logger.V(logutil.DEBUG).Info("Skipping pod due to prediction error", "pod", pod.GetPod().String(), "error", err)
predResult.Error = err
@@ -754,31 +753,31 @@ func (s *SLOScorer) generatePredictions(ctx context.Context, state *schedulingty
return predictions
}
-func (s *SLOScorer) getPodMinTPOTSLO(pod schedulingtypes.Pod) float64 {
+func (s *SLOAwareRouter) getPodMinTPOTSLO(pod schedulingtypes.Pod) float64 {
podName := types.NamespacedName{
Name: pod.GetPod().NamespacedName.Name,
Namespace: pod.GetPod().NamespacedName.Namespace,
}
- if runningReqs, err := s.datastore.PodGetRunningRequests(podName); err == nil && runningReqs != nil {
+ if runningReqs, ok := s.runningRequestLists[podName]; ok && runningReqs.GetSize() > 0 {
if topReq := runningReqs.Peek(); topReq != nil {
return topReq.TPOT
}
}
- return 0
+ return 0 // no running requests or no TPOT SLOs
}
-func (s *SLOScorer) getPodRunningRequestCount(pod schedulingtypes.Pod) int {
+func (s *SLOAwareRouter) getPodRunningRequestCount(pod schedulingtypes.Pod) int {
podName := types.NamespacedName{
Name: pod.GetPod().NamespacedName.Name,
Namespace: pod.GetPod().NamespacedName.Namespace,
}
- if runningReqs, err := s.datastore.PodGetRequestCount(podName); err == nil {
- return runningReqs
+ if runningReqs, ok := s.runningRequestLists[podName]; ok {
+ return runningReqs.GetSize()
}
- return 0
+ return 0 // no running requests
}
-func (s *SLOScorer) validatePrediction(
+func (s *SLOAwareRouter) validatePrediction(
pred *latencypredictor.PredictionResponse,
req *schedulingtypes.LLMRequest,
podMinTPOTSLO float64,
@@ -803,7 +802,7 @@ func (s *SLOScorer) validatePrediction(
return
}
-func (s *SLOScorer) getPrefixCacheScoreForPod(ctx context.Context, cycleState *schedulingtypes.CycleState, pod schedulingtypes.Pod) float64 {
+func (s *SLOAwareRouter) getPrefixCacheScoreForPod(ctx context.Context, cycleState *schedulingtypes.CycleState, pod schedulingtypes.Pod) float64 {
log.FromContext(ctx).V(logutil.DEBUG).Info("Running getPrefixCacheScoreForPod, getting prefix cache score for pod", "pod", pod.GetPod().String())
plugintype := prefix.PrefixCachePluginType
pluginname := prefix.PrefixCachePluginType
@@ -838,7 +837,7 @@ func (s *SLOScorer) getPrefixCacheScoreForPod(ctx context.Context, cycleState *s
}
// updateRequestContextWithPredictions updates the request context with prediction data
-func (s *SLOScorer) updateRequestContextWithPredictions(request *schedulingtypes.LLMRequest, predictions []PodPredictionResult) {
+func (s *SLOAwareRouter) updateRequestContextWithPredictions(request *schedulingtypes.LLMRequest, predictions []PodPredictionResult) {
for _, pred := range predictions {
if pred.Error == nil {
podKey := pred.Pod.GetPod().String()
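
For illustration only, a hypothetical sketch of the blended-deficit idea described in the comments above (the names, the 0.5 blend weight, and the linear blend are assumptions, not the plugin's actual implementation): pods whose predictions miss their SLOs accrue a deficit per dimension, the TTFT and TPOT deficits are blended, and a lower blended deficit should translate into a higher selection weight.

```go
package main

import "fmt"

// podPred is a hypothetical stand-in for a per-pod prediction result.
type podPred struct {
	name              string
	predTTFT, ttftSLO float64 // milliseconds
	predTPOT, tpotSLO float64 // milliseconds
}

// blendedDeficit combines how far each prediction overshoots its SLO.
// alpha weights the TTFT deficit against the TPOT deficit.
func blendedDeficit(p podPred, alpha float64) float64 {
	ttftDeficit := max(0, p.predTTFT-p.ttftSLO)
	tpotDeficit := max(0, p.predTPOT-p.tpotSLO)
	return alpha*ttftDeficit + (1-alpha)*tpotDeficit
}

func main() {
	pods := []podPred{
		{name: "pod-a", predTTFT: 900, ttftSLO: 800, predTPOT: 55, tpotSLO: 50},
		{name: "pod-b", predTTFT: 700, ttftSLO: 800, predTPOT: 70, tpotSLO: 50},
	}
	const alpha = 0.5 // assumed blend weight
	for _, p := range pods {
		fmt.Printf("%s blended deficit: %.1f (lower => higher selection weight)\n",
			p.name, blendedDeficit(p, alpha))
	}
}
```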
diff --git a/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go
new file mode 100644
index 000000000..b1c66d5a9
--- /dev/null
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/requestcontrol_hooks.go
@@ -0,0 +1,218 @@
+package slo_aware_router
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/go-logr/logr"
+ "github.com/google/uuid"
+ "sigs.k8s.io/controller-runtime/pkg/log"
+
+ "k8s.io/apimachinery/pkg/types"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
+ backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
+ schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+ requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
+)
+
+var _ requestcontrol.PreRequest = &SLOAwareRouter{}
+var _ requestcontrol.ResponseReceived = &SLOAwareRouter{}
+var _ requestcontrol.ResponseStreaming = &SLOAwareRouter{}
+var _ requestcontrol.ResponseComplete = &SLOAwareRouter{}
+
+type SLORequestContext struct {
+ SchedulingRequest schedulingtypes.LLMRequest
+ TargetPod *backend.Pod
+ SchedulingResult *schedulingtypes.SchedulingResult
+ LastSeenMetrics map[string]*backendmetrics.MetricsState
+ LastTokenTimestamp time.Time
+ RequestReceivedTimestamp time.Time
+ GeneratedTokenCount int
+ IncomingModelName string
+ TTFT float64
+ PredictedTTFT float64
+ AvgTPOT float64
+ AvgPredictedTPOT float64
+ TokenSampler *requtil.TokenSampler
+ TPOTObservations []float64
+ PredictedTPOTObservations []float64
+}
+
+func NewSLORequestContext(request *schedulingtypes.LLMRequest) *SLORequestContext {
+ return &SLORequestContext{
+ SchedulingRequest: *request,
+ LastSeenMetrics: make(map[string]*backendmetrics.MetricsState),
+ }
+}
+
+func (s *SLOAwareRouter) getSLOContextForRequest(request *schedulingtypes.LLMRequest) (*SLORequestContext, error) {
+ id := request.Headers[requtil.RequestIdHeaderKey]
+ if ctx, exists := s.sloContextStore[id]; exists {
+ return ctx, nil
+ }
+ return nil, fmt.Errorf("SLO context not found for request ID: %s", id)
+}
+
+func (s *SLOAwareRouter) setSLOContextForRequest(request *schedulingtypes.LLMRequest, ctx *SLORequestContext) {
+ id := request.Headers[requtil.RequestIdHeaderKey]
+ s.sloContextStore[id] = ctx
+}
+
+func (s *SLOAwareRouter) deleteSLOContextForRequest(request *schedulingtypes.LLMRequest) {
+ id := request.Headers[requtil.RequestIdHeaderKey]
+ delete(s.sloContextStore, id)
+}
+
+// --- RequestControl Hooks ---
+
+func (t *SLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult) {
+ logger := log.FromContext(ctx)
+
+ if schedulingResult == nil || len(schedulingResult.ProfileResults) == 0 {
+ logger.V(logutil.DEBUG).Info("SLOAwareRouter: Skipping PreRequest because no scheduling result was provided.")
+ return
+ }
+
+ targetPod := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName].TargetPods[0].GetPod()
+
+ podName := types.NamespacedName{
+ Name: targetPod.NamespacedName.Name,
+ Namespace: targetPod.NamespacedName.Namespace,
+ }
+
+ logger.V(logutil.DEBUG).Info("request ID for SLO tracking", "requestID", request.Headers[requtil.RequestIdHeaderKey], "podName", podName)
+ if request.Headers[requtil.RequestIdHeaderKey] == "" {
+ request.Headers[requtil.RequestIdHeaderKey] = uuid.New().String()
+ logger.V(logutil.DEBUG).Info("Generated new request ID for SLO tracking", "requestID", request.Headers[requtil.RequestIdHeaderKey])
+ logger.V(logutil.DEBUG).Info("request headers for SLO tracking", "requestHeaders", request.Headers)
+ }
+
+ id := request.Headers[requtil.RequestIdHeaderKey]
+ podRequestList, ok := t.runningRequestLists[podName]
+ if !ok {
+ podRequestList = NewRequestPriorityQueue()
+ t.runningRequestLists[podName] = podRequestList
+ }
+
+ added := podRequestList.Add(id, request.AvgTPOTSLO)
+ if !added {
+ logger.V(logutil.DEBUG).Info("SLOAwareRouter: Item already exists in queue", "podName", podName, "requestID", id)
+ }
+
+ // Set up SLO request context
+ sloCtx := NewSLORequestContext(request)
+ sloCtx.TargetPod = targetPod
+ sloCtx.SchedulingResult = schedulingResult
+ RefreshLastSeenMetrics(ctx, sloCtx)
+ t.setSLOContextForRequest(request, sloCtx)
+}
+
+func (t *SLOAwareRouter) ResponseReceived(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) {
+ logger := log.FromContext(ctx)
+ id := request.Headers[requtil.RequestIdHeaderKey]
+
+ sloCtx, err := t.getSLOContextForRequest(request)
+ if err != nil {
+ logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Failed to get SLO context for request", "requestID", id)
+ return
+ }
+
+ if !t.CheckPredictor(logger, targetPod) {
+ return
+ }
+
+ if err := ProcessHeaderForLatencyPrediction(ctx, t.latencypredictor, sloCtx); err != nil {
+ logger.V(logutil.DEBUG).Error(err, "ProcessHeader in latencypredictor failed")
+ }
+
+}
+
+func (t *SLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) {
+ logger := log.FromContext(ctx)
+ if !t.CheckPredictor(logger, pod) {
+ return
+ }
+
+ now := time.Now()
+ sloCtx, err := t.getSLOContextForRequest(request)
+ if err != nil {
+ id := request.Headers[requtil.RequestIdHeaderKey]
+ logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter.ResponseStreaming: Failed to get SLO context for request", "requestID", id)
+ return
+ }
+
+ if sloCtx.TTFT == 0 {
+ ProcessFirstTokenForLatencyPrediction(ctx, t.latencypredictor, sloCtx, now)
+ } else {
+ ProcessTokenForLatencyPrediction(ctx, t.latencypredictor, sloCtx, now)
+ }
+
+}
+
+func (t *SLOAwareRouter) ResponseComplete(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *backend.Pod) {
+ logger := log.FromContext(ctx)
+ targetPod := pod
+ if !t.CheckPredictor(logger, targetPod) {
+ return
+ }
+
+ sloCtx, err := t.getSLOContextForRequest(request)
+ if err != nil {
+ id := request.Headers[requtil.RequestIdHeaderKey]
+ logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter.ResponseComplete: Failed to get SLO context for request", "requestID", id)
+ return
+ }
+
+ if sloCtx.TTFT > 0 {
+ logger.V(logutil.DEBUG).Info("Averages calculated", "avgActualTTFT", sloCtx.TTFT, "avgPredictedTTFT", sloCtx.PredictedTTFT)
+ metrics.RecordRequestTTFT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.TTFT/1000)
+ metrics.RecordRequestPredictedTTFT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.PredictedTTFT/1000)
+ if sloCtx.SchedulingRequest.TTFTSLO > 0 {
+ metrics.RecordRequestTTFTWithSLO(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.TTFT, sloCtx.SchedulingRequest.TTFTSLO)
+ }
+ }
+
+ if sloCtx.AvgTPOT > 0 {
+ logger.V(logutil.DEBUG).Info("Averages calculated", "avgActualTPOT", sloCtx.AvgTPOT, "avgPredictedTPOT", sloCtx.AvgPredictedTPOT)
+ metrics.RecordRequestTPOT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgTPOT/1000)
+ metrics.RecordRequestPredictedTPOT(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgPredictedTPOT/1000)
+ if sloCtx.SchedulingRequest.AvgTPOTSLO > 0 {
+ metrics.RecordRequestTPOTWithSLO(ctx, sloCtx.IncomingModelName, request.TargetModel, sloCtx.AvgTPOT, sloCtx.SchedulingRequest.AvgTPOTSLO)
+ }
+ }
+ logger.V(logutil.DEBUG).Info("SLO Aware Routing Mode", "PredictorBasedScheduling", request.PredictorBasedScheduling)
+
+ podName := types.NamespacedName{
+ Name: targetPod.NamespacedName.Name,
+ Namespace: targetPod.NamespacedName.Namespace,
+ }
+
+ id := request.Headers[requtil.RequestIdHeaderKey]
+ podRequestList, ok := t.runningRequestLists[podName]
+ if !ok {
+ err := fmt.Errorf("no running request list found for pod %s", podName.String())
+ logger.V(logutil.DEBUG).Error(err, "SLOAwareRouter: Failed to remove request from queue", "requestID", id)
+		t.deleteSLOContextForRequest(request)
+		return
+	}
+
+	_, removed := podRequestList.Remove(id)
+	if !removed {
+ logger.V(logutil.DEBUG).Info("SLOAwareRouter: Item not found in queue", "podName", podName, "requestID", id)
+ }
+ t.deleteSLOContextForRequest(request)
+}
+
+func (t *SLOAwareRouter) CheckPredictor(logger logr.Logger, targetPod *backend.Pod) bool {
+ if targetPod == nil {
+		logger.V(logutil.DEBUG).Info("SLOAwareRouter: Skipping response processing because no target pod was provided.")
+ return false
+ }
+ if t.latencypredictor == nil {
+		logger.V(logutil.DEBUG).Info("SLOAwareRouter: Skipping response processing because the predictor is missing")
+ return false
+ }
+ return true
+}
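
For illustration only, a tiny stand-in (not the plugin's actual RequestPriorityQueue) showing why peeking a min-heap keyed on each running request's TPOT SLO yields the pod's strictest active constraint, which is what getPodMinTPOTSLO relies on:

```go
package main

import (
	"container/heap"
	"fmt"
)

// entry pairs a request ID with its TPOT SLO (ms).
type entry struct {
	id   string
	tpot float64
}

// minQueue is a hypothetical min-heap ordered by TPOT SLO.
type minQueue []entry

func (q minQueue) Len() int           { return len(q) }
func (q minQueue) Less(i, j int) bool { return q[i].tpot < q[j].tpot }
func (q minQueue) Swap(i, j int)      { q[i], q[j] = q[j], q[i] }
func (q *minQueue) Push(x any)        { *q = append(*q, x.(entry)) }
func (q *minQueue) Pop() any {
	old := *q
	n := len(old)
	it := old[n-1]
	*q = old[:n-1]
	return it
}

func main() {
	q := &minQueue{}
	heap.Init(q)
	heap.Push(q, entry{id: "req-a", tpot: 80})
	heap.Push(q, entry{id: "req-b", tpot: 40})
	heap.Push(q, entry{id: "req-c", tpot: 120})

	// The heap root is the tightest TPOT SLO currently running on the pod.
	fmt.Println("strictest TPOT SLO:", (*q)[0].tpot) // 40
}
```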
diff --git a/pkg/epp/datalayer/running_request_queue.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go
similarity index 99%
rename from pkg/epp/datalayer/running_request_queue.go
rename to pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go
index 29bef911a..1199be641 100644
--- a/pkg/epp/datalayer/running_request_queue.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue.go
@@ -1,4 +1,4 @@
-package datalayer
+package slo_aware_router
import (
"container/heap"
diff --git a/pkg/epp/datalayer/running_request_queue_test.go b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go
similarity index 99%
rename from pkg/epp/datalayer/running_request_queue_test.go
rename to pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go
index bac82106d..a8eba5fe1 100644
--- a/pkg/epp/datalayer/running_request_queue_test.go
+++ b/pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/running_request_queue_test.go
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
-package datalayer
+package slo_aware_router
import (
"fmt"
diff --git a/pkg/epp/scheduling/framework/plugins/picker/common.go b/pkg/epp/scheduling/framework/plugins/picker/common.go
index 4bbc300da..c8655840f 100644
--- a/pkg/epp/scheduling/framework/plugins/picker/common.go
+++ b/pkg/epp/scheduling/framework/plugins/picker/common.go
@@ -16,6 +16,13 @@ limitations under the License.
package picker
+import (
+ "math/rand/v2"
+ "time"
+
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+)
+
const (
DefaultMaxNumOfEndpoints = 1 // common default to all pickers
)
@@ -24,3 +31,14 @@ const (
type pickerParameters struct {
MaxNumOfEndpoints int `json:"maxNumOfEndpoints"`
}
+
+func shuffleScoredPods(scoredPods []*types.ScoredPod) {
+ // Rand package is not safe for concurrent use, so we create a new instance.
+ // Source: https://pkg.go.dev/math/rand/v2#pkg-overview
+ randomGenerator := rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), 0))
+
+ // Shuffle in-place
+ randomGenerator.Shuffle(len(scoredPods), func(i, j int) {
+ scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i]
+ })
+}
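
As an aside (my own sketch, not part of the change): math/rand/v2's PCG source takes two uint64 seed words, so a test can pin both for a reproducible shuffle, while the helper above seeds from the clock for production use.

```go
package main

import (
	"fmt"
	"math/rand/v2"
)

func main() {
	// Fixed seed words give a deterministic permutation, handy for tests.
	r := rand.New(rand.NewPCG(1, 2))
	vals := []string{"pod-a", "pod-b", "pod-c", "pod-d"}
	r.Shuffle(len(vals), func(i, j int) { vals[i], vals[j] = vals[j], vals[i] })
	fmt.Println(vals)
}
```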
diff --git a/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go b/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go
index 325f735fa..33e99bd06 100644
--- a/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go
+++ b/pkg/epp/scheduling/framework/plugins/picker/max_score_picker.go
@@ -20,9 +20,7 @@ import (
"context"
"encoding/json"
"fmt"
- "math/rand"
"slices"
- "time"
"sigs.k8s.io/controller-runtime/pkg/log"
@@ -85,15 +83,8 @@ func (p *MaxScorePicker) Pick(ctx context.Context, cycleState *types.CycleState,
log.FromContext(ctx).V(logutil.DEBUG).Info("Selecting pods from candidates sorted by max score", "max-num-of-endpoints", p.maxNumOfEndpoints,
"num-of-candidates", len(scoredPods), "scored-pods", scoredPods)
- // TODO: merge this with the logic in RandomPicker
- // Rand package is not safe for concurrent use, so we create a new instance.
- // Source: https://pkg.go.dev/math/rand#pkg-overview
- randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano()))
-
// Shuffle in-place - needed for random tie break when scores are equal
- randomGenerator.Shuffle(len(scoredPods), func(i, j int) {
- scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i]
- })
+ shuffleScoredPods(scoredPods)
slices.SortStableFunc(scoredPods, func(i, j *types.ScoredPod) int { // highest score first
if i.Score > j.Score {
diff --git a/pkg/epp/scheduling/framework/plugins/picker/picker_test.go b/pkg/epp/scheduling/framework/plugins/picker/picker_test.go
index 741a49d59..022328efd 100644
--- a/pkg/epp/scheduling/framework/plugins/picker/picker_test.go
+++ b/pkg/epp/scheduling/framework/plugins/picker/picker_test.go
@@ -18,6 +18,7 @@ package picker
import (
"context"
+ "math"
"testing"
"github.com/google/go-cmp/cmp"
@@ -138,8 +139,8 @@ func TestPickMaxScorePicker(t *testing.T) {
func TestPickWeightedRandomPicker(t *testing.T) {
const (
- testIterations = 1000
- tolerance = 0.2 // 20% tolerance in [0,1] range
+ testIterations = 10000
+ tolerance = 0.05 // Verify within tolerance ±5%
)
pod1 := &types.PodMetrics{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}
@@ -197,14 +198,14 @@ func TestPickWeightedRandomPicker(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
picker := NewWeightedRandomPicker(test.maxPods)
- selectionCounts := make(map[string]int)
- // Calculate expected probabilities based on scores
+	// Sum the scores of all pods
totalScore := 0.0
for _, pod := range test.input {
totalScore += pod.Score
}
+ // Calculate expected probabilities based on scores
expectedProbabilities := make(map[string]float64)
for _, pod := range test.input {
podName := pod.GetPod().NamespacedName.Name
@@ -216,20 +217,19 @@ func TestPickWeightedRandomPicker(t *testing.T) {
}
// Initialize selection counters for each pod
+ selectionCounts := make(map[string]int)
for _, pod := range test.input {
podName := pod.GetPod().NamespacedName.Name
selectionCounts[podName] = 0
}
// Run multiple iterations to gather statistical data
- for i := 0; i < testIterations; i++ {
+ for range testIterations {
result := picker.Pick(context.Background(), types.NewCycleState(), test.input)
// Count selections for probability analysis
- if len(result.TargetPods) > 0 {
- selectedPodName := result.TargetPods[0].GetPod().NamespacedName.Name
- selectionCounts[selectedPodName]++
- }
+ selectedPodName := result.TargetPods[0].GetPod().NamespacedName.Name
+ selectionCounts[selectedPodName]++
}
// Verify probability distribution
@@ -237,11 +237,7 @@ func TestPickWeightedRandomPicker(t *testing.T) {
actualCount := selectionCounts[podName]
actualProb := float64(actualCount) / float64(testIterations)
- toleranceValue := expectedProb * tolerance
- lowerBound := expectedProb - toleranceValue
- upperBound := expectedProb + toleranceValue
-
- if actualProb < lowerBound || actualProb > upperBound {
+ if math.Abs(actualProb-expectedProb) > tolerance {
t.Errorf("Pod %s: expected probability %.3f ±%.1f%%, got %.3f (count: %d/%d)",
podName, expectedProb, tolerance*100, actualProb, actualCount, testIterations)
} else {
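
A quick sanity check on the tighter bound introduced above (my own arithmetic, not part of the test): with 10,000 iterations the standard error of an observed selection frequency is sqrt(p(1-p)/n), which peaks around 0.005 at p = 0.5, so an absolute ±5% band is roughly ten standard errors wide and should keep the test stable.

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	const iterations = 10000.0
	for _, p := range []float64{0.1, 0.3, 0.5} {
		se := math.Sqrt(p * (1 - p) / iterations)
		fmt.Printf("expected prob %.1f -> std error %.4f (5%% band ≈ %.0f std errors)\n",
			p, se, 0.05/se)
	}
}
```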
diff --git a/pkg/epp/scheduling/framework/plugins/picker/random_picker.go b/pkg/epp/scheduling/framework/plugins/picker/random_picker.go
index 87a1747fc..10ad68469 100644
--- a/pkg/epp/scheduling/framework/plugins/picker/random_picker.go
+++ b/pkg/epp/scheduling/framework/plugins/picker/random_picker.go
@@ -20,8 +20,6 @@ import (
"context"
"encoding/json"
"fmt"
- "math/rand"
- "time"
"sigs.k8s.io/controller-runtime/pkg/log"
@@ -84,15 +82,8 @@ func (p *RandomPicker) Pick(ctx context.Context, _ *types.CycleState, scoredPods
log.FromContext(ctx).V(logutil.DEBUG).Info("Selecting pods from candidates randomly", "max-num-of-endpoints", p.maxNumOfEndpoints,
"num-of-candidates", len(scoredPods), "scored-pods", scoredPods)
- // TODO: merge this with the logic in MaxScorePicker
- // Rand package is not safe for concurrent use, so we create a new instance.
- // Source: https://pkg.go.dev/math/rand#pkg-overview
- randomGenerator := rand.New(rand.NewSource(time.Now().UnixNano()))
-
// Shuffle in-place
- randomGenerator.Shuffle(len(scoredPods), func(i, j int) {
- scoredPods[i], scoredPods[j] = scoredPods[j], scoredPods[i]
- })
+ shuffleScoredPods(scoredPods)
// if we have enough pods to return keep only the relevant subset
if p.maxNumOfEndpoints < len(scoredPods) {
diff --git a/pkg/epp/scheduling/framework/plugins/scorer/kvcache_utilization.go b/pkg/epp/scheduling/framework/plugins/scorer/kvcache_utilization.go
index 48d982cd8..6db2c23e8 100644
--- a/pkg/epp/scheduling/framework/plugins/scorer/kvcache_utilization.go
+++ b/pkg/epp/scheduling/framework/plugins/scorer/kvcache_utilization.go
@@ -65,9 +65,8 @@ func (s *KVCacheUtilizationScorer) WithName(name string) *KVCacheUtilizationScor
}
// Score returns the scoring result for the given list of pods based on context.
-func (s *KVCacheUtilizationScorer) Score(_ context.Context, _ *types.CycleState, req *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
+func (s *KVCacheUtilizationScorer) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
scores := make(map[types.Pod]float64, len(pods))
-
for _, pod := range pods {
scores[pod] = 1 - pod.GetMetrics().KVCacheUsagePercent
}
diff --git a/pkg/epp/scheduling/framework/plugins/scorer/lora_affinity.go b/pkg/epp/scheduling/framework/plugins/scorer/lora_affinity.go
index d3cbad4b4..fc5b8f7c4 100644
--- a/pkg/epp/scheduling/framework/plugins/scorer/lora_affinity.go
+++ b/pkg/epp/scheduling/framework/plugins/scorer/lora_affinity.go
@@ -65,7 +65,6 @@ func (s *LoraAffinityScorer) WithName(name string) *LoraAffinityScorer {
}
func (s *LoraAffinityScorer) Score(_ context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
-
scores := make(map[types.Pod]float64, len(pods))
if request.PredictorBasedScheduling {
diff --git a/pkg/epp/scheduling/framework/plugins/scorer/queue.go b/pkg/epp/scheduling/framework/plugins/scorer/queue.go
index 9f9fd763a..0db645283 100644
--- a/pkg/epp/scheduling/framework/plugins/scorer/queue.go
+++ b/pkg/epp/scheduling/framework/plugins/scorer/queue.go
@@ -67,8 +67,7 @@ func (s *QueueScorer) WithName(name string) *QueueScorer {
}
// Score returns the scoring result for the given list of pods based on context.
-func (s *QueueScorer) Score(_ context.Context, _ *types.CycleState, req *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
-
+func (s *QueueScorer) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
minQueueSize := math.MaxInt
maxQueueSize := math.MinInt
diff --git a/pkg/epp/scheduling/framework/plugins/test/filter/request_header_based_filter.go b/pkg/epp/scheduling/framework/plugins/test/filter/request_header_based_filter.go
index bf36d7782..2836755d4 100644
--- a/pkg/epp/scheduling/framework/plugins/test/filter/request_header_based_filter.go
+++ b/pkg/epp/scheduling/framework/plugins/test/filter/request_header_based_filter.go
@@ -73,7 +73,7 @@ func (f *HeaderBasedTestingFilter) Filter(_ context.Context, _ *types.CycleState
podAddressMap := make(map[string]types.Pod, len(pods))
for _, pod := range pods {
- podAddressMap[pod.GetPod().Address] = pod
+ podAddressMap[pod.GetPod().GetIPAddress()] = pod
}
endpoints := strings.Split(headerValue, ",")
diff --git a/pkg/epp/scheduling/framework/scheduler_profile.go b/pkg/epp/scheduling/framework/scheduler_profile.go
index 7bb6fc653..f3833f5d7 100644
--- a/pkg/epp/scheduling/framework/scheduler_profile.go
+++ b/pkg/epp/scheduling/framework/scheduler_profile.go
@@ -191,6 +191,14 @@ func (p *SchedulerProfile) runScorerPlugins(ctx context.Context, request *types.
}
+ for pod, score := range scores {
+ logger.V(logutil.DEBUG).Info("Pod score",
+ "scorer_type", scorer.TypedName().Type,
+ "scorer_name", scorer.TypedName().Name,
+ "pod_namespace", pod.GetPod().NamespacedName.Namespace,
+ "pod_name", pod.GetPod().NamespacedName.Name,
+ "score", score)
+ }
for pod, score := range scores {
logger.V(logutil.DEBUG).Info("Pod score",
"scorer_type", scorer.TypedName().Type,
diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go
index 056723dbf..15c5b6658 100644
--- a/pkg/epp/scheduling/types/types.go
+++ b/pkg/epp/scheduling/types/types.go
@@ -17,20 +17,25 @@ limitations under the License.
package types
import (
+ "encoding/json"
+ "errors"
"fmt"
+ "strings"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
)
+const nilString = "
+
+### A Note on Interacting with Multiple API Versions
+
+During the zero-downtime migration, both `v1alpha2` and `v1` CRDs will be installed on your cluster. This can create ambiguity when using `kubectl` to query for `InferencePool` resources. To ensure you are interacting with the correct version, you **must** use the full resource name:
+
+* **For v1alpha2**: `kubectl get inferencepools.inference.networking.x-k8s.io`
+* **For v1**: `kubectl get inferencepools.inference.networking.k8s.io`
+
+The `v1` API also provides a convenient short name, `infpool`, which can be used to query `v1` resources specifically:
+
+```bash
+kubectl get infpool
+```
+
+This guide will use these full names or the short name for `v1` to avoid ambiguity.
+
+***
+
+### Stage 1: Side-by-side v1 Deployment
+
+In this stage, you will deploy the new `v1` `InferencePool` stack alongside the existing `v1alpha2` stack. This allows for a safe, gradual migration.
+
+After finishing all the steps in this stage, you'll have the infrastructure shown in the diagram below.
+
+
+
+**1. Install v1 CRDs**
+
+```bash
+RELEASE=v1.0.0
+kubectl apply -f https://github.com/kubernetes-sigs/gateway-api-inference-extension/raw/$RELEASE/config/crd/bases/inference.networking.x-k8s.io_inferenceobjectives.yaml
+```
+
+**2. Install the v1 `InferencePool`**
+
+Use Helm to install a new `v1` `InferencePool` with a distinct release name (e.g., `vllm-llama3-8b-instruct-ga`).
+
+```bash
+helm install vllm-llama3-8b-instruct-ga \
+ --set inferencePool.modelServers.matchLabels.app=