@@ -2,6 +2,8 @@ package programaware
22
33import (
44 "fmt"
5+ "slices"
6+ "sync"
57 "time"
68
79 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/flowcontrol"
@@ -37,6 +39,10 @@ type ScoringStrategy interface {
3739
3840 // OnCompleted is called when a response finishes with actual token usage.
3941 OnCompleted (metrics * ProgramMetrics , promptTokens , completionTokens int64 )
42+
43+ // OnPicked is called after Pick() selects a queue, allowing strategies
44+ // to update internal state (e.g., round-robin cursor).
45+ OnPicked (programID string )
4046}
4147
4248// newStrategy constructs a ScoringStrategy from the plugin config.
@@ -55,8 +61,10 @@ func newStrategy(cfg Config) (ScoringStrategy, error) {
5561 decayFactor : floatOr (cfg .ServiceDecayFactor , defaultServiceDecayFactor ),
5662 halfLifeSeconds : floatOr (cfg .ServiceHalfLifeSeconds , 0 ),
5763 }, nil
64+ case "rr" :
65+ return & RRStrategy {}, nil
5866 default :
59- return nil , fmt .Errorf ("unknown scoring strategy %q: valid values are \" drr\" , \" service\" " , cfg .Strategy )
67+ return nil , fmt .Errorf ("unknown scoring strategy %q: valid values are \" drr\" , \" service\" , \" rr \" " , cfg .Strategy )
6068 }
6169}
6270
@@ -177,6 +185,9 @@ func (s *DRRStrategy) Score(normalized []float64) float64 {
177185 s .weightHeadWait * normalized [drrDimHeadWait ]
178186}
179187
188+ // OnPicked is a no-op for DRR (cursor not needed).
189+ func (s * DRRStrategy ) OnPicked (_ string ) {}
190+
180191// OnCompleted deducts actual token usage from the deficit counter.
181192func (s * DRRStrategy ) OnCompleted (metrics * ProgramMetrics , promptTokens , completionTokens int64 ) {
182193 if metrics == nil {
@@ -277,6 +288,9 @@ func (s *ServiceStrategy) Score(normalized []float64) float64 {
277288 s .weightHeadWait * normalized [serviceDimHeadWait ]
278289}
279290
291+ // OnPicked is a no-op for Service (cursor not needed).
292+ func (s * ServiceStrategy ) OnPicked (_ string ) {}
293+
280294// OnCompleted accumulates the weighted token cost into the program's attained service.
281295func (s * ServiceStrategy ) OnCompleted (metrics * ProgramMetrics , promptTokens , completionTokens int64 ) {
282296 if metrics == nil {
@@ -285,3 +299,109 @@ func (s *ServiceStrategy) OnCompleted(metrics *ProgramMetrics, promptTokens, com
285299 cost := float64 (weightInputToken * promptTokens + weightOutputToken * completionTokens )
286300 metrics .AddService (cost )
287301}
302+
303+ // =============================================================================
304+ // RR (Round-Robin) Strategy
305+ // =============================================================================
306+
307+ // RR dimension indices.
308+ const (
309+ rrDimPosition = 0
310+ rrNumDimensions = 1
311+ )
312+
313+ // RRStrategy implements a simple round-robin scheduling strategy that matches
314+ // the upstream gateway-api-inference-extension round-robin fairness policy.
315+ //
316+ // It maintains a cursor (lastSelected) that tracks which program was last
317+ // dispatched. On each Pick() cycle, programs are sorted deterministically
318+ // and the one immediately after the cursor gets the highest score.
319+ // Empty queues are naturally skipped because Pick() only scores non-empty queues.
320+ type RRStrategy struct {
321+ mu sync.Mutex
322+ lastSelected string // program ID last picked
323+ cycleKeys []string // sorted program IDs collected during current cycle
324+ cycleActive bool // true once OnPickStart has been called for this cycle
325+ }
326+
327+ // Name returns "rr".
328+ func (s * RRStrategy ) Name () string { return "rr" }
329+
330+ // OnPickStart collects program IDs for deterministic ordering.
331+ // Called once per queue per Pick() cycle. On the first call of a new cycle,
332+ // resets the key list.
333+ func (s * RRStrategy ) OnPickStart (programID string , _ int , _ * ProgramMetrics ) {
334+ s .mu .Lock ()
335+ defer s .mu .Unlock ()
336+ if ! s .cycleActive {
337+ s .cycleKeys = s .cycleKeys [:0 ]
338+ s .cycleActive = true
339+ }
340+ s .cycleKeys = append (s .cycleKeys , programID )
341+ }
342+
343+ // NumDimensions returns 1 (position-based score).
344+ func (s * RRStrategy ) NumDimensions () int { return rrNumDimensions }
345+
346+ // CollectRaw computes a position-based score for the queue.
347+ // Queues closer to the cursor's "next" position get higher scores.
348+ func (s * RRStrategy ) CollectRaw (queue flowcontrol.FlowQueueAccessor , _ * ProgramMetrics ) []float64 {
349+ s .mu .Lock ()
350+ defer s .mu .Unlock ()
351+
352+ // Sort keys for deterministic ordering (same as upstream).
353+ keys := make ([]string , len (s .cycleKeys ))
354+ copy (keys , s .cycleKeys )
355+ slices .Sort (keys )
356+
357+ numFlows := len (keys )
358+ if numFlows == 0 {
359+ return []float64 {0 }
360+ }
361+
362+ // Find the start index (next after lastSelected), matching upstream logic.
363+ startIndex := 0
364+ if s .lastSelected != "" {
365+ if idx := slices .Index (keys , s .lastSelected ); idx != - 1 {
366+ startIndex = (idx + 1 ) % numFlows
367+ }
368+ }
369+
370+ // Find this queue's position in the sorted list.
371+ programID := queue .FlowKey ().ID
372+ queueIdx := slices .Index (keys , programID )
373+ if queueIdx == - 1 {
374+ return []float64 {0 }
375+ }
376+
377+ // Compute distance from startIndex (wrapping around).
378+ // Distance 0 = highest priority, distance numFlows-1 = lowest.
379+ distance := (queueIdx - startIndex + numFlows ) % numFlows
380+
381+ // Convert to score: closer to cursor's next position = higher score.
382+ score := float64 (numFlows - distance )
383+ return []float64 {score }
384+ }
385+
386+ // NormalizeDimension is a passthrough for RR — normalization is not meaningful
387+ // since the position-based scores already encode the correct ordering.
388+ func (s * RRStrategy ) NormalizeDimension (_ int , raw , _ , _ float64 ) float64 {
389+ return raw
390+ }
391+
392+ // Score returns the position score directly.
393+ func (s * RRStrategy ) Score (normalized []float64 ) float64 {
394+ return normalized [rrDimPosition ]
395+ }
396+
397+ // OnPicked updates the round-robin cursor and clears per-cycle state.
398+ func (s * RRStrategy ) OnPicked (programID string ) {
399+ s .mu .Lock ()
400+ defer s .mu .Unlock ()
401+ s .lastSelected = programID
402+ s .cycleKeys = s .cycleKeys [:0 ]
403+ s .cycleActive = false
404+ }
405+
406+ // OnCompleted is a no-op for round-robin (no token tracking needed).
407+ func (s * RRStrategy ) OnCompleted (_ * ProgramMetrics , _ , _ int64 ) {}
0 commit comments