Skip to content

Commit c40cc15

Browse files
usizeguygir
authored and committed
feat: Add a scoring plugin to distribute new groups evenly (llm-d#357)
1 parent 555cdda commit c40cc15

File tree

5 files changed

+817
-0
lines changed

5 files changed

+817
-0
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Sample EPP configuration for running without P/D
# with small hash block size for simulation purposes
apiVersion: inference.networking.x-k8s.io/v1alpha1
kind: EndpointPickerConfig
plugins:
- type: prefix-cache-scorer
  parameters:
    hashBlockSize: 5
    maxPrefixBlocksToMatch: 256
    lruCapacityPerServer: 31250
# no-hit-lru-scorer spreads cold (no cache hit) requests across pods in
# least-recently-used order; it reads the prefix-cache-scorer state above.
- type: no-hit-lru-scorer
  parameters:
    lruSize: 2048
- type: decode-filter
- type: max-score-picker
- type: single-profile-handler
schedulingProfiles:
- name: default
  plugins:
  - pluginRef: decode-filter
  - pluginRef: max-score-picker
  - pluginRef: prefix-cache-scorer
    weight: 2
  - pluginRef: no-hit-lru-scorer
    weight: 1

docs/architecture.md

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,57 @@ used for the same session.
364364

365365
---
366366

367+
#### NoHitLRUScorer
368+
369+
Scores pods based on least recently used (LRU) ordering for cold requests (requests with no KV cache hits).
370+
This helps evenly distribute cache growth across pods, since cold requests result in new KV blocks being created.
371+
372+
The scorer integrates with a prefix cache plugin to determine if a request has cache hits:
373+
- For cold requests (no cache hits): Ranks pods by LRU order, with never-used or least recently used pods
374+
receiving higher scores (up to 1.0) and most recently used pods receiving lower scores (approaching 0.0)
375+
- For warm requests (cache hits): Returns neutral scores (0.5) for all pods to avoid interfering with
376+
cache locality optimization
377+
378+
The LRU tracking is specific to cold requests only - pods are added to the LRU cache when they serve
379+
a cold request, not when they serve requests with cache hits.
380+
381+
- **Type**: `no-hit-lru-scorer`
382+
- **Parameters**:
383+
- `prefixPluginName` (optional): The name of the prefix cache plugin to read state from. Defaults to `prefix-cache-scorer`.
384+
- `lruSize` (optional): The maximum number of pods to track in the LRU cache. Defaults to 1024.
385+
386+
Example configuration:
387+
388+
```yaml
389+
plugins:
390+
- type: prefix-cache-scorer
391+
parameters:
392+
hashBlockSize: 5
393+
maxPrefixBlocksToMatch: 256
394+
lruCapacityPerServer: 31250
395+
- type: no-hit-lru-scorer
396+
parameters:
397+
lruSize: 2048
398+
- type: decode-filter
399+
- type: max-score-picker
400+
- type: single-profile-handler
401+
schedulingProfiles:
402+
- name: default
403+
plugins:
404+
- pluginRef: decode-filter
405+
- pluginRef: max-score-picker
406+
- pluginRef: prefix-cache-scorer
407+
weight: 2
408+
- pluginRef: no-hit-lru-scorer
409+
weight: 1
410+
```
411+
412+
**Note:** This scorer is designed to work alongside a prefix cache scorer (such as `prefix-cache-scorer` or
413+
`precise-prefix-cache-scorer`). If no prefix cache state is available, all requests are treated as cold.
414+
When integrating with a prefix-cache scorer, the prefix-cache scorer should be defined first in the scheduling profile.
415+
416+
---
417+
367418
### Sample Disaggregated Prefill/Decode Configuration
368419

369420
The following is an example of what a configuration for disaggregated Prefill/Decode might look like:

pkg/plugins/register.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@ func RegisterAllPlugins() {
2020
plugins.Register(scorer.LoadAwareType, scorer.LoadAwareFactory)
2121
plugins.Register(scorer.SessionAffinityType, scorer.SessionAffinityFactory)
2222
plugins.Register(scorer.ActiveRequestType, scorer.ActiveRequestFactory)
23+
plugins.Register(scorer.NoHitLRUType, scorer.NoHitLRUFactory)
2324
}

pkg/plugins/scorer/no_hit_lru.go

Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
package scorer
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
8+
lru "github.com/hashicorp/golang-lru/v2"
9+
"sigs.k8s.io/controller-runtime/pkg/log"
10+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
11+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
12+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
13+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
14+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
15+
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
16+
)
17+
18+
const (
	// NoHitLRUType is the type of the NoHitLRU scorer, used both for plugin
	// registration and as the default cycle-state key of the scorer.
	NoHitLRUType = "no-hit-lru-scorer"

	// defaultLRUSize is the maximum number of pods we'll consider in the cache
	// when no lruSize parameter is configured.
	defaultLRUSize = 1024
)
25+
26+
// Compile-time assertions that NoHitLRU implements both the framework's
// Scorer extension point and the request-control PreRequest hook.
var _ framework.Scorer = &NoHitLRU{}
var _ requestcontrol.PreRequest = &NoHitLRU{}
29+
30+
// NoHitLRUParameters defines the configurable parameters for the NoHitLRU
// scorer, unmarshaled from the plugin's raw JSON configuration.
type NoHitLRUParameters struct {
	// PrefixPluginName defines the name of the prefix cache plugin to read state from.
	// Defaults to "prefix-cache-scorer".
	PrefixPluginName string `json:"prefixPluginName"`

	// LRUSize defines the maximum number of pods to track in the LRU cache.
	// Values <= 0 are ignored and the default (1024) is used instead.
	LRUSize int `json:"lruSize"`
}
39+
40+
// coldRequestState tracks whether a request triggered a KV cache hit.
// When the cache is missed (no prefix-cache hits), isCold is true.
// It is written by Score and consumed (then deleted) by PreRequest.
type coldRequestState struct {
	// isCold is true when the request had no prefix cache hits.
	isCold bool
}
45+
46+
// Clone implements the plugins.StateData interface
47+
func (c *coldRequestState) Clone() plugins.StateData {
48+
return &coldRequestState{isCold: c.isCold}
49+
}
50+
51+
// NoHitLRUFactory defines the factory function for the NoHitLRU
52+
func NoHitLRUFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
53+
parameters := NoHitLRUParameters{}
54+
if rawParameters != nil {
55+
if err := json.Unmarshal(rawParameters, &parameters); err != nil {
56+
return nil, fmt.Errorf("failed to parse the parameters of the '%s' scorer - %w", NoHitLRUType, err)
57+
}
58+
}
59+
60+
if parameters.PrefixPluginName == "" {
61+
parameters.PrefixPluginName = prefix.PrefixCachePluginType
62+
}
63+
64+
// Note: We don't enforce that the prefix plugin exists here
65+
// The scorer will gracefully handle missing prefix cache state as an optimization
66+
67+
return NewNoHitLRU(handle.Context(), &parameters).WithName(name), nil
68+
}
69+
70+
// NewNoHitLRU creates a new NoHitLRU scorer
71+
func NewNoHitLRU(ctx context.Context, params *NoHitLRUParameters) *NoHitLRU {
72+
prefixPluginName := prefix.PrefixCachePluginType
73+
lruSize := defaultLRUSize
74+
75+
if params != nil {
76+
if params.PrefixPluginName != "" {
77+
prefixPluginName = params.PrefixPluginName
78+
}
79+
if params.LRUSize > 0 {
80+
lruSize = params.LRUSize
81+
}
82+
}
83+
84+
lruCache, err := lru.New[string, struct{}](lruSize)
85+
if err != nil {
86+
log.FromContext(ctx).Error(err, fmt.Sprintf("failed to initialize NoHitLRU scorer: could not create LRU cache with size %d: %v", lruSize, err))
87+
return nil
88+
}
89+
90+
return &NoHitLRU{
91+
typedName: plugins.TypedName{Type: NoHitLRUType},
92+
lruCache: lruCache,
93+
prefixPluginName: prefixPluginName,
94+
pluginState: plugins.NewPluginState(ctx),
95+
}
96+
}
97+
98+
// NoHitLRU scorer that favors pods that were least recently used for cold requests.
// This can help evenly distribute cache growth, since cold requests result in more
// new KV blocks.
type NoHitLRU struct {
	// typedName holds the plugin type and configured instance name.
	typedName plugins.TypedName
	// lruCache tracks pod names in recency order of serving cold requests;
	// the value is an empty struct because only the ordering matters.
	lruCache *lru.Cache[string, struct{}] // pod name -> dummy value (we only care about order)
	// prefixPluginName is the cycle-state key of the prefix cache plugin whose
	// state is read to classify requests as cold or warm.
	prefixPluginName string
	// pluginState carries the per-request cold/warm decision from Score to PreRequest.
	pluginState *plugins.PluginState
}
107+
108+
// TypedName returns the typed name (plugin type plus instance name) of the plugin.
func (s *NoHitLRU) TypedName() plugins.TypedName {
	return s.typedName
}
112+
113+
// WithName sets the instance name of the plugin and returns the receiver
// to allow chaining.
func (s *NoHitLRU) WithName(name string) *NoHitLRU {
	s.typedName.Name = name
	return s
}
118+
119+
// isColdRequest determines if a request is cold by reading the prefix cache state.
120+
// Returns true if no prefix cache hits were found, or if prefix cache state is unavailable.
121+
func (s *NoHitLRU) isColdRequest(ctx context.Context, cycleState *types.CycleState) bool {
122+
logger := log.FromContext(ctx).V(logutil.DEBUG)
123+
124+
// Read prefix cache state to determine if this is a cold request
125+
// This is treated as an optimization - if the state isn't available, we assume cold request
126+
prefixState, err := types.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, plugins.StateKey(s.prefixPluginName))
127+
128+
if err != nil {
129+
logger.Info("No prefix cache state found, treating as cold request for LRU optimization", "error", err)
130+
return true
131+
}
132+
133+
// Check if this is a cold request (no prefix cache hits)
134+
return len(prefixState.PrefixCacheServers) == 0
135+
}
136+
137+
// scoreNeutral returns neutral scores (0.5) for all pods.
138+
// Used when a request has cache hits and LRU optimization should not apply.
139+
func (s *NoHitLRU) scoreNeutral(pods []types.Pod) map[types.Pod]float64 {
140+
scoredPods := make(map[types.Pod]float64, len(pods))
141+
for _, pod := range pods {
142+
scoredPods[pod] = 0.5
143+
}
144+
return scoredPods
145+
}
146+
147+
// getLRUPositions returns a map of pod names to their LRU position.
148+
// Position 0 represents the oldest (least recently used) entry.
149+
func (s *NoHitLRU) getLRUPositions() map[string]int {
150+
// Get all keys from LRU cache in order (oldest first)
151+
// https://pkg.go.dev/github.com/hashicorp/golang-lru/v2#Cache.Keys
152+
lruKeys := s.lruCache.Keys()
153+
154+
lruPosition := make(map[string]int, len(lruKeys))
155+
for i, key := range lruKeys {
156+
lruPosition[key] = i
157+
}
158+
return lruPosition
159+
}
160+
161+
// partitionPodsByUsage separates pods into those that have received cold requests
162+
// (usedPods) and those that have never received cold requests (neverUsedPods).
163+
func (s *NoHitLRU) partitionPodsByUsage(pods []types.Pod, lruPosition map[string]int) (usedPods, neverUsedPods []types.Pod) {
164+
for _, pod := range pods {
165+
podName := pod.GetPod().NamespacedName.String()
166+
if _, exists := lruPosition[podName]; exists {
167+
usedPods = append(usedPods, pod)
168+
} else {
169+
neverUsedPods = append(neverUsedPods, pod)
170+
}
171+
}
172+
return usedPods, neverUsedPods
173+
}
174+
175+
// scoreNeverUsedPods assigns scores to pods that have never received a cold request.
176+
// The first never-used pod gets the highest score (1.0), with subsequent pods
177+
// receiving progressively lower scores.
178+
func (s *NoHitLRU) scoreNeverUsedPods(scoredPods map[types.Pod]float64, neverUsedPods []types.Pod, totalPods int) {
179+
// Avoid possibility of dividing by zero.
180+
if totalPods <= 1 {
181+
return
182+
}
183+
for i, pod := range neverUsedPods {
184+
score := 1.0 - float64(i)/float64(totalPods-1)
185+
scoredPods[pod] = score
186+
}
187+
}
188+
189+
// scoreUsedPods assigns scores to pods based on their LRU position.
190+
// Pods that were least recently used for cold requests receive higher scores.
191+
func (s *NoHitLRU) scoreUsedPods(scoredPods map[types.Pod]float64, usedPods []types.Pod, lruPosition map[string]int, neverUsedCount, totalPods int) {
192+
// Avoid possibility of dividing by zero.
193+
if totalPods <= 1 {
194+
return
195+
}
196+
for _, pod := range usedPods {
197+
podName := pod.GetPod().NamespacedName.String()
198+
lruPos := lruPosition[podName]
199+
// LRU keys are oldest to newest so rank 0 = oldest
200+
// The never used pod count is added to the rank so that
201+
// a never-used pod will always have the highest score.
202+
rank := neverUsedCount + lruPos
203+
score := 1.0 - float64(rank)/float64(totalPods-1)
204+
if score < 0 {
205+
score = 0
206+
}
207+
scoredPods[pod] = score
208+
}
209+
}
210+
211+
// scoreColdRequestByLRU scores pods based on their LRU position for cold requests.
212+
// Pods that have never received a cold request get the highest scores.
213+
// Among previously used pods, least recently used ones get higher scores.
214+
func (s *NoHitLRU) scoreColdRequestByLRU(pods []types.Pod) map[types.Pod]float64 {
215+
scoredPods := make(map[types.Pod]float64, len(pods))
216+
totalPods := len(pods)
217+
218+
// Avoid possibility of dividing by zero.
219+
if totalPods == 1 {
220+
scoredPods[pods[0]] = 1.0
221+
return scoredPods
222+
}
223+
224+
lruPosition := s.getLRUPositions()
225+
usedPods, neverUsedPods := s.partitionPodsByUsage(pods, lruPosition)
226+
227+
s.scoreNeverUsedPods(scoredPods, neverUsedPods, totalPods)
228+
s.scoreUsedPods(scoredPods, usedPods, lruPosition, len(neverUsedPods), totalPods)
229+
230+
return scoredPods
231+
}
232+
233+
// Score scores the given pods based on LRU for cold requests.
234+
// For cache hits, returns neutral scores (0.5) for all pods.
235+
// For cache misses, ranks pods by their LRU order.
236+
// - LRU ordering is with respect to when a pod last received a cold request.
237+
// - Least recently used (or never used) pods get highest score (1.0)
238+
// - Most recently used pods get lowest score (approaching 0.0)
239+
func (s *NoHitLRU) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
240+
logger := log.FromContext(ctx).V(logutil.DEBUG)
241+
242+
isCold := s.isColdRequest(ctx, cycleState)
243+
244+
// Store the cold request state in plugin state for PreRequest to use
245+
coldState := &coldRequestState{isCold: isCold}
246+
s.pluginState.Write(request.RequestId, plugins.StateKey(s.typedName.String()), coldState)
247+
248+
if !isCold {
249+
logger.Info("Cache hit detected, returning neutral scores")
250+
return s.scoreNeutral(pods)
251+
}
252+
253+
logger.Info("Cold request detected, scoring pods by LRU")
254+
return s.scoreColdRequestByLRU(pods)
255+
}
256+
257+
// PreRequest is called before a request is sent to the target pod.
// For cold requests, it updates the LRU cache to track which pods have been used recently.
// Warm requests and requests without scheduling results or stored state leave
// the LRU cache untouched.
func (s *NoHitLRU) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult, _ int) {
	logger := log.FromContext(ctx).V(logutil.DEBUG)

	if schedulingResult == nil || len(schedulingResult.ProfileResults) == 0 {
		logger.Info("No scheduling result available")
		return
	}

	// Read the cold request state we stored in Score
	coldState, err := plugins.ReadPluginStateKey[*coldRequestState](s.pluginState, request.RequestId, plugins.StateKey(s.typedName.String()))
	// After fetching the cold state, drop it from the plugin state immediately (otherwise it will hang around until it becomes stale).
	// Note: the Delete deliberately happens before the error check so the
	// per-request entry is cleaned up on both the success and failure paths.
	s.pluginState.Delete(request.RequestId)

	if err != nil {
		logger.Info("No cold request state found, treating as non-cold request", "error", err)
		return
	}

	if !coldState.isCold {
		logger.Info("Not a cold request, skipping LRU update")
		return
	}

	// Get the primary profile's target pod
	primaryProfile := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
	if primaryProfile == nil || len(primaryProfile.TargetPods) == 0 {
		logger.Info("No target pod in primary profile")
		return
	}

	// Only the first target pod of the primary profile is tracked.
	targetPod := primaryProfile.TargetPods[0]
	podName := targetPod.GetPod().NamespacedName.String()

	// Move the pod to the front of the LRU.
	var present struct{} // dummy value
	s.lruCache.Add(podName, present)

	logger.Info("Updated LRU cache for cold request", "pod", podName, "requestId", request.RequestId)
}

0 commit comments

Comments
 (0)