Skip to content

Commit 78b19f3

Browse files
committed
proximity scorer
1 parent 9a55e4f commit 78b19f3

File tree

4 files changed

+156
-0
lines changed

4 files changed

+156
-0
lines changed

deploy/config/pd-epp-config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ kind: EndpointPickerConfig
44
plugins:
55
- type: prefill-header-handler
66
- type: prefix-cache-scorer
7+
- type: proximity-scorer
8+
parameters:
9+
sameNodeScore: 1.0
10+
defaultScore: 0.0
711
- type: prefill-filter
812
- type: decode-filter
913
- type: max-score-picker
@@ -15,6 +19,8 @@ schedulingProfiles:
1519
- pluginRef: max-score-picker
1620
- pluginRef: prefix-cache-scorer
1721
weight: 2
22+
- pluginRef: proximity-scorer
23+
weight: 1
1824
- name: decode
1925
plugins:
2026
- pluginRef: decode-filter

pkg/plugins/profile/pd_profile_handler.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,16 @@ const (
2828
defaultDecodeProfile = "decode"
2929
defaultPrefillProfile = "prefill"
3030
defaultPrefixPluginType = prefix.PrefixCachePluginType
31+
32+
// ProximityStateKey is the CycleState key for storing decode pod proximity information
33+
ProximityStateKey = "proximity-state"
3134
)
3235

36+
// ProximityState is the value the P/D profile handler writes into CycleState
// under ProximityStateKey. It records which node the selected decode pod runs
// on so that a later scoring step can compare prefill candidates against it.
type ProximityState struct {
	// DecodeNodeName is the node name taken from the chosen decode pod's metadata.
	DecodeNodeName string
}
40+
3341
type pdProfileHandlerParameters struct {
3442
Threshold int `json:"threshold"`
3543
DecodeProfile string `json:"decodeProfile"`
@@ -169,6 +177,21 @@ func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *types.CycleStat
169177
}
170178

171179
metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypePrefillDecode)
180+
181+
// Extract decode pod's node name and store in CycleState for proximity scoring
182+
decodePod := profileResults[h.decodeProfile].TargetPods[0]
183+
decodePodMetadata := decodePod.GetPod()
184+
nodeName := decodePodMetadata.NodeName
185+
186+
// Store decode pod node name in CycleState for proximity scorer
187+
proximityState := &ProximityState{
188+
DecodeNodeName: nodeName,
189+
}
190+
cycleState.Write(plugins.StateKey(ProximityStateKey), proximityState)
191+
192+
log.FromContext(ctx).V(logutil.DEBUG).Info("Stored decode pod node for proximity scoring",
193+
"nodeName", nodeName)
194+
172195
// run the prefill profile
173196
return map[string]*framework.SchedulerProfile{
174197
h.prefillProfile: profiles[h.prefillProfile],

pkg/plugins/register.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@ func RegisterAllPlugins() {
2222
plugins.Register(scorer.SessionAffinityType, scorer.SessionAffinityFactory)
2323
plugins.Register(scorer.ActiveRequestType, scorer.ActiveRequestFactory)
2424
plugins.Register(scorer.NoHitLRUType, scorer.NoHitLRUFactory)
25+
plugins.Register(scorer.ProximityScorerType, scorer.ProximityScorerFactory)
2526
}

pkg/plugins/scorer/proximity.go

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
package scorer
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
8+
"sigs.k8s.io/controller-runtime/pkg/log"
9+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
10+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
11+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
12+
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
13+
14+
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile"
15+
)
16+
17+
// ProximityScorerType is the plugin type name of the ProximityScorer; it is
// the value used for the scorer's TypedName.Type.
const ProximityScorerType = "proximity-scorer"
21+
22+
// ProximityScorerParams holds the two configurable score values used by the
// proximity scorer. Both are decoded from the plugin's JSON parameters.
type ProximityScorerParams struct {
	// SameNodeScore is assigned to a prefill pod located on the same node as
	// the decode pod (same node: NVLink, ~900GB/s).
	SameNodeScore float64 `json:"sameNodeScore"`
	// DefaultScore is assigned to a prefill pod on any other node
	// (different node: InfiniBand, ~400GB/s).
	DefaultScore float64 `json:"defaultScore"`
}
27+
28+
// compile-time type assertion
29+
var _ framework.Scorer = &ProximityScorer{}
30+
31+
// ProximityScorerFactory defines the factory function for the ProximityScorer
32+
func ProximityScorerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
33+
params := ProximityScorerParams{
34+
SameNodeScore: 1.0,
35+
DefaultScore: 0.0,
36+
}
37+
38+
if rawParameters != nil {
39+
if err := json.Unmarshal(rawParameters, &params); err != nil {
40+
return nil, fmt.Errorf("failed to parse proximity scorer parameters: %w", err)
41+
}
42+
}
43+
44+
return NewProximityScorer(params).WithName(name), nil
45+
}
46+
47+
// NewProximityScorer creates a new ProximityScorer instance
48+
func NewProximityScorer(params ProximityScorerParams) *ProximityScorer {
49+
return &ProximityScorer{
50+
typedName: plugins.TypedName{Type: ProximityScorerType},
51+
params: params,
52+
}
53+
}
54+
55+
// ProximityScorer is a scorer that prefers prefill pods on the same node as the decode pod
56+
type ProximityScorer struct {
57+
typedName plugins.TypedName
58+
params ProximityScorerParams
59+
}
60+
61+
// TypedName returns the typed name of the plugin
62+
func (s *ProximityScorer) TypedName() plugins.TypedName {
63+
return s.typedName
64+
}
65+
66+
// WithName sets the name of the plugin
67+
func (s *ProximityScorer) WithName(name string) *ProximityScorer {
68+
s.typedName.Name = name
69+
return s
70+
}
71+
72+
// Category returns the scorer category (Affinity to prefer locality)
73+
func (s *ProximityScorer) Category() framework.ScorerCategory {
74+
return framework.Affinity
75+
}
76+
77+
// Score assigns scores to prefill pods based on their proximity to the decode pod
78+
func (s *ProximityScorer) Score(ctx context.Context, cycleState *types.CycleState,
79+
request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
80+
81+
scores := make(map[types.Pod]float64, len(pods))
82+
83+
// Read decode pod node name from CycleState
84+
proximityState, err := types.ReadCycleStateKey[*profile.ProximityState](
85+
cycleState,
86+
plugins.StateKey(profile.ProximityStateKey),
87+
)
88+
89+
if err != nil {
90+
// If decode pod node name not available, give all pods default score
91+
log.FromContext(ctx).V(logutil.DEBUG).Info(
92+
"Proximity state not found in CycleState, using default scores",
93+
"error", err,
94+
)
95+
for _, pod := range pods {
96+
scores[pod] = s.params.DefaultScore
97+
}
98+
return scores
99+
}
100+
101+
// Score each prefill pod based on proximity to decode pod
102+
for _, pod := range pods {
103+
metadata := pod.GetPod()
104+
105+
// Get prefill pod's node name
106+
prefillNodeName := metadata.NodeName
107+
108+
// Simple binary scoring: same node = 1.0, different node = 0.0
109+
score := s.params.DefaultScore
110+
if prefillNodeName != "" && prefillNodeName == proximityState.DecodeNodeName {
111+
score = s.params.SameNodeScore // Same node - NVLink (~900GB/s)
112+
}
113+
// else: Different node - InfiniBand (~400GB/s)
114+
115+
scores[pod] = score
116+
117+
log.FromContext(ctx).V(logutil.DEBUG).Info("Proximity score calculated",
118+
"pod", metadata.NamespacedName.String(),
119+
"prefillNode", prefillNodeName,
120+
"decodeNode", proximityState.DecodeNodeName,
121+
"score", score,
122+
)
123+
}
124+
125+
return scores
126+
}

0 commit comments

Comments
 (0)