Skip to content

Commit 9392bfa

Browse files
Add round-robin (RR) scheduling strategy with tests
Signed-off-by: Dasari Surya Sai Venkatesh <suryasai.venkatesh@gmail.com>
1 parent 63b36ab commit 9392bfa

File tree

3 files changed

+318
-6
lines changed

3 files changed

+318
-6
lines changed

pkg/plugins/program-aware/plugin.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ const (
3131
// Config holds the JSON-decoded configuration for the plugin.
3232
type Config struct {
3333
// Strategy selects the fairness scoring algorithm used by Pick().
34-
// Valid values: "service" (default), "drr".
34+
// Valid values: "service" (default), "drr", "rr".
3535
//
3636
// "service" — attained service fairness: tracks time-decayed weighted tokens
3737
// consumed per program. Programs with lower attained service are
@@ -41,6 +41,10 @@ type Config struct {
4141
// Each round every active queue earns a token quantum; actual token
4242
// usage is deducted at response completion. Provides provably
4343
// proportional fairness independent of request rate or size.
44+
//
45+
// "rr" — Simple round-robin: cycles through program queues in sorted order,
46+
// skipping empty queues. Matches the upstream round-robin fairness
47+
// policy. No token or service tracking.
4448
Strategy string `json:"strategy"`
4549

4650
// --- DRR weights (only used when strategy == "drr") ---
@@ -246,6 +250,14 @@ func (p *ProgramAwarePlugin) Pick(_ context.Context, band flowcontrol.PriorityBa
246250
}
247251
}
248252

253+
// Notify the strategy that the Pick() cycle is complete.
254+
// When no queue was selected, an empty ID is passed so that a cursor-based strategy resets its cursor (matches upstream).
255+
pickedID := ""
256+
if bestQueue != nil {
257+
pickedID = bestQueue.FlowKey().ID
258+
}
259+
strategy.OnPicked(pickedID)
260+
249261
// Record the selected item's enqueue time so PreRequest can compute
250262
// the actual flow-control queue wait time (enqueue → dispatch).
251263
if bestQueue != nil {

pkg/plugins/program-aware/strategy.go

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package programaware
22

33
import (
44
"fmt"
5+
"slices"
6+
"sync"
57
"time"
68

79
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/flowcontrol"
@@ -37,6 +39,10 @@ type ScoringStrategy interface {
3739

3840
// OnCompleted is called when a response finishes with actual token usage.
3941
OnCompleted(metrics *ProgramMetrics, promptTokens, completionTokens int64)
42+
43+
// OnPicked is called after Pick() selects a queue, allowing strategies
44+
// to update internal state (e.g., round-robin cursor).
45+
OnPicked(programID string)
4046
}
4147

4248
// newStrategy constructs a ScoringStrategy from the plugin config.
@@ -55,8 +61,10 @@ func newStrategy(cfg Config) (ScoringStrategy, error) {
5561
decayFactor: floatOr(cfg.ServiceDecayFactor, defaultServiceDecayFactor),
5662
halfLifeSeconds: floatOr(cfg.ServiceHalfLifeSeconds, 0),
5763
}, nil
64+
case "rr":
65+
return &RRStrategy{}, nil
5866
default:
59-
return nil, fmt.Errorf("unknown scoring strategy %q: valid values are \"drr\", \"service\"", cfg.Strategy)
67+
return nil, fmt.Errorf("unknown scoring strategy %q: valid values are \"drr\", \"service\", \"rr\"", cfg.Strategy)
6068
}
6169
}
6270

@@ -177,6 +185,9 @@ func (s *DRRStrategy) Score(normalized []float64) float64 {
177185
s.weightHeadWait*normalized[drrDimHeadWait]
178186
}
179187

188+
// OnPicked is a no-op for DRR (cursor not needed).
189+
func (s *DRRStrategy) OnPicked(_ string) {}
190+
180191
// OnCompleted deducts actual token usage from the deficit counter.
181192
func (s *DRRStrategy) OnCompleted(metrics *ProgramMetrics, promptTokens, completionTokens int64) {
182193
if metrics == nil {
@@ -277,6 +288,9 @@ func (s *ServiceStrategy) Score(normalized []float64) float64 {
277288
s.weightHeadWait*normalized[serviceDimHeadWait]
278289
}
279290

291+
// OnPicked is a no-op for Service (cursor not needed).
292+
func (s *ServiceStrategy) OnPicked(_ string) {}
293+
280294
// OnCompleted accumulates the weighted token cost into the program's attained service.
281295
func (s *ServiceStrategy) OnCompleted(metrics *ProgramMetrics, promptTokens, completionTokens int64) {
282296
if metrics == nil {
@@ -285,3 +299,109 @@ func (s *ServiceStrategy) OnCompleted(metrics *ProgramMetrics, promptTokens, com
285299
cost := float64(weightInputToken*promptTokens + weightOutputToken*completionTokens)
286300
metrics.AddService(cost)
287301
}
302+
303+
// =============================================================================
304+
// RR (Round-Robin) Strategy
305+
// =============================================================================
306+
307+
// Indices into the score vector used by the round-robin strategy.
const (
	// rrDimPosition is the index of the rotation-position score.
	rrDimPosition = iota
	// rrNumDimensions is the number of dimensions the strategy reports.
	rrNumDimensions
)
312+
313+
// RRStrategy is a plain round-robin scheduling strategy intended to mirror
// the round-robin fairness policy of the upstream
// gateway-api-inference-extension project.
//
// The strategy remembers which program was dispatched last (the cursor).
// During a Pick() cycle the program IDs registered through OnPickStart are
// put into a deterministic sorted order, and the program that follows the
// cursor in that order receives the highest score. Queues that are never
// scored by Pick() simply do not take part in the rotation.
//
// All fields are guarded by mu.
type RRStrategy struct {
	// mu serializes access to the fields below.
	mu sync.Mutex

	// lastSelected is the ID of the program chosen by the previous Pick();
	// the empty string means "no cursor".
	lastSelected string

	// cycleKeys holds the program IDs collected during the current cycle.
	cycleKeys []string

	// cycleActive becomes true on the first OnPickStart call of a cycle and
	// is cleared again by OnPicked.
	cycleActive bool
}
326+
327+
// Name returns "rr".
328+
func (s *RRStrategy) Name() string { return "rr" }
329+
330+
// OnPickStart collects program IDs for deterministic ordering.
331+
// Called once per queue per Pick() cycle. On the first call of a new cycle,
332+
// resets the key list.
333+
func (s *RRStrategy) OnPickStart(programID string, _ int, _ *ProgramMetrics) {
334+
s.mu.Lock()
335+
defer s.mu.Unlock()
336+
if !s.cycleActive {
337+
s.cycleKeys = s.cycleKeys[:0]
338+
s.cycleActive = true
339+
}
340+
s.cycleKeys = append(s.cycleKeys, programID)
341+
}
342+
343+
// NumDimensions returns 1 (position-based score).
344+
func (s *RRStrategy) NumDimensions() int { return rrNumDimensions }
345+
346+
// CollectRaw computes a position-based score for the queue.
347+
// Queues closer to the cursor's "next" position get higher scores.
348+
func (s *RRStrategy) CollectRaw(queue flowcontrol.FlowQueueAccessor, _ *ProgramMetrics) []float64 {
349+
s.mu.Lock()
350+
defer s.mu.Unlock()
351+
352+
// Sort keys for deterministic ordering (same as upstream).
353+
keys := make([]string, len(s.cycleKeys))
354+
copy(keys, s.cycleKeys)
355+
slices.Sort(keys)
356+
357+
numFlows := len(keys)
358+
if numFlows == 0 {
359+
return []float64{0}
360+
}
361+
362+
// Find the start index (next after lastSelected), matching upstream logic.
363+
startIndex := 0
364+
if s.lastSelected != "" {
365+
if idx := slices.Index(keys, s.lastSelected); idx != -1 {
366+
startIndex = (idx + 1) % numFlows
367+
}
368+
}
369+
370+
// Find this queue's position in the sorted list.
371+
programID := queue.FlowKey().ID
372+
queueIdx := slices.Index(keys, programID)
373+
if queueIdx == -1 {
374+
return []float64{0}
375+
}
376+
377+
// Compute distance from startIndex (wrapping around).
378+
// Distance 0 = highest priority, distance numFlows-1 = lowest.
379+
distance := (queueIdx - startIndex + numFlows) % numFlows
380+
381+
// Convert to score: closer to cursor's next position = higher score.
382+
score := float64(numFlows - distance)
383+
return []float64{score}
384+
}
385+
386+
// NormalizeDimension is a passthrough for RR — normalization is not meaningful
387+
// since the position-based scores already encode the correct ordering.
388+
func (s *RRStrategy) NormalizeDimension(_ int, raw, _, _ float64) float64 {
389+
return raw
390+
}
391+
392+
// Score returns the position score directly.
393+
func (s *RRStrategy) Score(normalized []float64) float64 {
394+
return normalized[rrDimPosition]
395+
}
396+
397+
// OnPicked updates the round-robin cursor and clears per-cycle state.
398+
func (s *RRStrategy) OnPicked(programID string) {
399+
s.mu.Lock()
400+
defer s.mu.Unlock()
401+
s.lastSelected = programID
402+
s.cycleKeys = s.cycleKeys[:0]
403+
s.cycleActive = false
404+
}
405+
406+
// OnCompleted is a no-op for round-robin (no token tracking needed).
407+
func (s *RRStrategy) OnCompleted(_ *ProgramMetrics, _, _ int64) {}

0 commit comments

Comments
 (0)