Skip to content

Commit 38add61

Browse files
committed
Add P/D-aware SLO scheduling support
- Add PDPredictionRequestBuilder to populate PodType from llm-d.ai/role labels
- Add pd-slo-aware-scorer plugin wrapping slo_aware_router with P/D builder
- Register pd-slo-aware-scorer in plugin registry
- Add example EPP config for P/D SLO-aware scheduling (pd-slo-epp-config.yaml)
- Add comprehensive guide on P/D SLO scheduling (docs/pd-slo-aware-scheduling.md)

Enables separate latency prediction models for prefill vs decode workloads.
1 parent 1f9ddb4 commit 38add61

File tree

9 files changed

+541
-22
lines changed

9 files changed

+541
-22
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
apiVersion: inference.networking.x-k8s.io/v1alpha1
kind: EndpointPickerConfig
plugins:
- type: prefill-filter
- type: decode-filter
- type: prefill-header-handler
- type: prefix-cache-scorer
  parameters:
    blockSize: 5
- type: pd-slo-aware-scorer
  parameters:
    sloBufferFactor: 1.0
    headroomSelectionStrategy: "least"
- type: max-score-picker
- type: pd-profile-handler
  parameters:
    threshold: 100
    decodeProfile: "decode"
    prefillProfile: "prefill"
    hashBlockSize: 5

schedulingProfiles:
- name: decode
  plugins:
  - pluginRef: decode-filter
  - pluginRef: prefix-cache-scorer
  # Fixed: was "pd-slo-scorer", which does not match any declared plugin type above.
  - pluginRef: pd-slo-aware-scorer
    weight: 100
  - pluginRef: max-score-picker

- name: prefill
  plugins:
  - pluginRef: prefill-filter
  - pluginRef: prefix-cache-scorer
  # Fixed: was "pd-slo-scorer", which does not match any declared plugin type above.
  - pluginRef: pd-slo-aware-scorer
    weight: 100
  - pluginRef: max-score-picker

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ require (
2929
sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260128235548-fd30cb97714a
3030
)
3131

32+
replace sigs.k8s.io/gateway-api-inference-extension => github.com/RishabhSaini/gateway-api-inference-extension v0.0.0-20260202150317-4d55e2564b01
33+
3234
require (
3335
cel.dev/expr v0.24.0 // indirect
3436
github.com/Masterminds/semver/v3 v3.4.0 // indirect

go.sum

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgv
1616
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
1717
github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0=
1818
github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM=
19+
github.com/RishabhSaini/gateway-api-inference-extension v0.0.0-20260202150317-4d55e2564b01 h1:TWmpkx/DH6LasXPCGYkbyIugalQuiEvcZTvw6qWb7v8=
20+
github.com/RishabhSaini/gateway-api-inference-extension v0.0.0-20260202150317-4d55e2564b01/go.mod h1:lvMpB9a+Lk+xBi5Pk6teUG+NqA16WR8nRpmBNFJbflU=
1921
github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b h1:mimo19zliBX/vSQ6PWWSL9lK8qwHozUj03+zLoEB8O0=
2022
github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b/go.mod h1:fvzegU4vN3H1qMT+8wDmzjAcDONcgo2/SZ/TyfdUOFs=
2123
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
@@ -211,8 +213,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
211213
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
212214
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
213215
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
214-
github.com/llm-d/llm-d-kv-cache v0.5.0-rc1 h1:UkJZU8hGRdZKPeCiXnuGjLivqIS6yeFAl9pv4QDQcWY=
215-
github.com/llm-d/llm-d-kv-cache v0.5.0-rc1/go.mod h1:XyhzHBYeOWamBMPkuRySB5nJ0zzQpK/mbuXKqJRFT6A=
216216
github.com/llm-d/llm-d-kv-cache v0.5.0 h1:XQpkbg1yedGxn2w7QS/v/2YtrOZGp16Sw49KvMlQ1s0=
217217
github.com/llm-d/llm-d-kv-cache v0.5.0/go.mod h1:XyhzHBYeOWamBMPkuRySB5nJ0zzQpK/mbuXKqJRFT6A=
218218
github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo=
@@ -448,10 +448,6 @@ sigs.k8s.io/controller-runtime v0.22.5 h1:v3nfSUMowX/2WMp27J9slwGFyAt7IV0YwBxAkr
448448
sigs.k8s.io/controller-runtime v0.22.5/go.mod h1:pc5SoYWnWI6I+cBHYYdZ7B6YHZVY5xNfll88JB+vniI=
449449
sigs.k8s.io/gateway-api v1.4.1 h1:NPxFutNkKNa8UfLd2CMlEuhIPMQgDQ6DXNKG9sHbJU8=
450450
sigs.k8s.io/gateway-api v1.4.1/go.mod h1:AR5RSqciWP98OPckEjOjh2XJhAe2Na4LHyXD2FUY7Qk=
451-
sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260128073548-aea9ebe8cea3 h1:sobxO5HxXOd9RdhIUbUP0p+rZyn3ZFJAL6NolaHx1ZQ=
452-
sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260128073548-aea9ebe8cea3/go.mod h1:lvMpB9a+Lk+xBi5Pk6teUG+NqA16WR8nRpmBNFJbflU=
453-
sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260128235548-fd30cb97714a h1:Ce5CZ0R3c5H475uEuJ92FMgux3j99wDrSsI4ivTBEXQ=
454-
sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260128235548-fd30cb97714a/go.mod h1:lvMpB9a+Lk+xBi5Pk6teUG+NqA16WR8nRpmBNFJbflU=
455451
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 h1:IpInykpT6ceI+QxKBbEflcR5EXP7sU1kvOlxwZh5txg=
456452
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg=
457453
sigs.k8s.io/kustomize/api v0.21.0 h1:I7nry5p8iDJbuRdYS7ez8MUvw7XVNPcIP5GkzzuXIIQ=

pkg/plugins/register.go

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,17 @@ import (
1010

1111
// RegisterAllPlugins registers the factory functions of all plugins in this repository.
1212
func RegisterAllPlugins() {
13-
plugin.Register(filter.ByLabelType, filter.ByLabelFactory)
14-
plugin.Register(filter.ByLabelSelectorType, filter.ByLabelSelectorFactory)
15-
plugin.Register(filter.DecodeRoleType, filter.DecodeRoleFactory)
16-
plugin.Register(filter.PrefillRoleType, filter.PrefillRoleFactory)
17-
plugin.Register(prerequest.PrefillHeaderHandlerType, prerequest.PrefillHeaderHandlerFactory)
18-
plugin.Register(profile.DataParallelProfileHandlerType, profile.DataParallelProfileHandlerFactory)
19-
plugin.Register(profile.PdProfileHandlerType, profile.PdProfileHandlerFactory)
20-
plugin.Register(scorer.PrecisePrefixCachePluginType, scorer.PrecisePrefixCachePluginFactory)
21-
plugin.Register(scorer.LoadAwareType, scorer.LoadAwareFactory)
22-
plugin.Register(scorer.SessionAffinityType, scorer.SessionAffinityFactory)
23-
plugin.Register(scorer.ActiveRequestType, scorer.ActiveRequestFactory)
24-
plugin.Register(scorer.NoHitLRUType, scorer.NoHitLRUFactory)
13+
plugins.Register(filter.ByLabelType, filter.ByLabelFactory)
14+
plugins.Register(filter.ByLabelSelectorType, filter.ByLabelSelectorFactory)
15+
plugins.Register(filter.DecodeRoleType, filter.DecodeRoleFactory)
16+
plugins.Register(filter.PrefillRoleType, filter.PrefillRoleFactory)
17+
plugins.Register(prerequest.PrefillHeaderHandlerType, prerequest.PrefillHeaderHandlerFactory)
18+
plugins.Register(profile.DataParallelProfileHandlerType, profile.DataParallelProfileHandlerFactory)
19+
plugins.Register(profile.PdProfileHandlerType, profile.PdProfileHandlerFactory)
20+
plugins.Register(scorer.PrecisePrefixCachePluginType, scorer.PrecisePrefixCachePluginFactory)
21+
plugins.Register(scorer.LoadAwareType, scorer.LoadAwareFactory)
22+
plugins.Register(scorer.SessionAffinityType, scorer.SessionAffinityFactory)
23+
plugins.Register(scorer.ActiveRequestType, scorer.ActiveRequestFactory)
24+
plugins.Register(scorer.NoHitLRUType, scorer.NoHitLRUFactory)
25+
plugins.Register(scorer.PDSLOAwareScorerType, scorer.PDSLOAwareScorerFactory)
2526
}
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// Package scorer provides scoring plugins for the llm-d scheduler.
2+
package scorer
3+
4+
import (
5+
"context"
6+
"time"
7+
8+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
9+
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
10+
predictedlatency "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency"
11+
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
12+
13+
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter"
14+
)
15+
16+
// PDPredictionRequestBuilder extends the default builder with P/D pod type awareness.
17+
// This builder reads the llm-d.ai/role label from pods and populates the PodType field
18+
// in prediction and training requests, enabling the latency predictor to learn separate
19+
// models for prefill and decode workloads.
20+
type PDPredictionRequestBuilder struct {
21+
predictedlatency.DefaultPredictionRequestBuilder
22+
}
23+
24+
// NewPDPredictionRequestBuilder creates a new P/D-aware prediction request builder.
25+
func NewPDPredictionRequestBuilder() *PDPredictionRequestBuilder {
26+
return &PDPredictionRequestBuilder{}
27+
}
28+
29+
// extractPodType reads the llm-d.ai/role label from a pod and maps it to the predictor's pod_type field.
30+
// Returns:
31+
// - "prefill" for pods with llm-d.ai/role=prefill
32+
// - "decode" for pods with llm-d.ai/role=decode
33+
// - "" (empty) for pods with llm-d.ai/role=both or no label (monolithic)
34+
func (b *PDPredictionRequestBuilder) extractPodType(pod schedulingtypes.Endpoint) string {
35+
// Get pod labels from the underlying endpoint metadata
36+
backendPod := pod.GetMetadata()
37+
if backendPod == nil {
38+
return "" // No pod info, treat as monolithic
39+
}
40+
41+
labels := backendPod.Labels
42+
if labels == nil {
43+
return "" // No labels, treat as monolithic
44+
}
45+
46+
role, exists := labels[filter.RoleLabel] // "llm-d.ai/role"
47+
if !exists {
48+
return "" // No role label, treat as monolithic
49+
}
50+
51+
// Map llm-d roles to predictor pod types
52+
switch role {
53+
case filter.RolePrefill:
54+
return "prefill"
55+
case filter.RoleDecode:
56+
return "decode"
57+
case filter.RoleBoth:
58+
// Pods that can do both are treated as monolithic
59+
// (predictor doesn't have a specialized model for this)
60+
return ""
61+
default:
62+
return ""
63+
}
64+
}
65+
66+
// BuildPredictionRequest constructs a prediction request with pod type information.
67+
// Extends the default implementation by populating the PodType field based on the pod's role label.
68+
func (b *PDPredictionRequestBuilder) BuildPredictionRequest(
69+
ctx context.Context,
70+
pod schedulingtypes.Endpoint,
71+
metrics *datalayer.Metrics,
72+
prompt string,
73+
generatedTokens int,
74+
prefixCacheScore float64,
75+
) latencypredictor.PredictionRequest {
76+
// Get base request from parent implementation
77+
req := b.DefaultPredictionRequestBuilder.BuildPredictionRequest(
78+
ctx, pod, metrics, prompt, generatedTokens, prefixCacheScore,
79+
)
80+
81+
// Customize with pod type from llm-d.ai/role label
82+
req.PodType = b.extractPodType(pod)
83+
84+
return req
85+
}
86+
87+
// BuildTrainingEntry constructs a training entry with pod type information.
88+
// Extends the default implementation by populating the PodType field based on the pod's role label.
89+
func (b *PDPredictionRequestBuilder) BuildTrainingEntry(
90+
ctx context.Context,
91+
pod schedulingtypes.Endpoint,
92+
metrics *datalayer.Metrics,
93+
prompt string,
94+
actualTTFT float64,
95+
actualTPOT float64,
96+
timestamp time.Time,
97+
generatedTokens int,
98+
prefixCacheScore float64,
99+
) latencypredictor.TrainingEntry {
100+
// Get base entry from parent implementation
101+
entry := b.DefaultPredictionRequestBuilder.BuildTrainingEntry(
102+
ctx, pod, metrics, prompt, actualTTFT, actualTPOT, timestamp, generatedTokens, prefixCacheScore,
103+
)
104+
105+
// Customize with pod type from llm-d.ai/role label
106+
entry.PodType = b.extractPodType(pod)
107+
108+
return entry
109+
}
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
/*
2+
Copyright 2025 The llm-d Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package scorer
18+
19+
import (
20+
"context"
21+
"strconv"
22+
"time"
23+
24+
"sigs.k8s.io/controller-runtime/pkg/log"
25+
26+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
27+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol"
28+
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
29+
predictedlatency "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/predictedlatency"
30+
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging"
31+
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
32+
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
33+
)
34+
35+
// PDSLOAwareRouter wraps the base PredictedLatency to add P/D-specific hook logic.
36+
// This keeps P/D disaggregation concerns in llm-d-inference-scheduler rather than
37+
// leaking them into the generic gateway-api-inference-extension.
38+
type PDSLOAwareRouter struct {
39+
*predictedlatency.PredictedLatency
40+
}
41+
42+
var _ requestcontrol.PreRequest = &PDSLOAwareRouter{}
43+
var _ requestcontrol.ResponseReceived = &PDSLOAwareRouter{}
44+
var _ requestcontrol.ResponseStreaming = &PDSLOAwareRouter{}
45+
var _ requestcontrol.ResponseComplete = &PDSLOAwareRouter{}
46+
47+
// PreRequest delegates to the base router
48+
func (p *PDSLOAwareRouter) PreRequest(ctx context.Context, request *schedulingtypes.LLMRequest, schedulingResult *schedulingtypes.SchedulingResult) {
49+
p.PredictedLatency.PreRequest(ctx, request, schedulingResult)
50+
}
51+
52+
// ResponseReceived adds P/D-specific logic to extract prefill timing headers
53+
// before delegating to the base router.
54+
func (p *PDSLOAwareRouter) ResponseReceived(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, targetPod *datalayer.EndpointMetadata) {
55+
logger := log.FromContext(ctx)
56+
57+
// P/D-specific: Check for prefill timing headers from the decode sidecar
58+
if prefillTTFTStr, ok := response.Headers["x-prefill-ttft-ms"]; ok && prefillTTFTStr != "" {
59+
logger.V(logutil.DEBUG).Info("Detected prefill timing header",
60+
"ttft_ms", prefillTTFTStr,
61+
"requestID", request.Headers[requtil.RequestIdHeaderKey])
62+
63+
// Parse prefill TTFT
64+
prefillTTFT, err := strconv.ParseFloat(prefillTTFTStr, 64)
65+
if err != nil {
66+
logger.V(logutil.DEBUG).Error(err, "Failed to parse prefill TTFT header", "value", prefillTTFTStr)
67+
} else {
68+
// Record training data for the prefill pod
69+
p.recordPrefillTrainingData(ctx, request, prefillTTFT)
70+
}
71+
}
72+
73+
// Delegate to base router for decode prediction logic
74+
p.PredictedLatency.ResponseReceived(ctx, request, response, targetPod)
75+
}
76+
77+
// ResponseStreaming delegates to the base router
78+
func (p *PDSLOAwareRouter) ResponseStreaming(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *datalayer.EndpointMetadata) {
79+
p.PredictedLatency.ResponseStreaming(ctx, request, response, pod)
80+
}
81+
82+
// ResponseComplete delegates to the base router
83+
func (p *PDSLOAwareRouter) ResponseComplete(ctx context.Context, request *schedulingtypes.LLMRequest, response *requestcontrol.Response, pod *datalayer.EndpointMetadata) {
84+
p.PredictedLatency.ResponseComplete(ctx, request, response, pod)
85+
}
86+
87+
// recordPrefillTrainingData records training data for the prefill pod based on timing
88+
// reported by the decode sidecar via x-prefill-ttft-ms header.
89+
//
90+
// This method is P/D-specific and lives in llm-d-inference-scheduler because it:
91+
// - Assumes two-phase scheduling with "prefill" and "decode" profiles
92+
// - Knows about the llm-d.ai/role label structure
93+
// - Understands that prefill pods only handle TTFT (no TPOT)
94+
func (p *PDSLOAwareRouter) recordPrefillTrainingData(
95+
ctx context.Context,
96+
request *schedulingtypes.LLMRequest,
97+
actualPrefillTTFT float64,
98+
) {
99+
logger := log.FromContext(ctx)
100+
101+
// Get scheduling result for this request
102+
schedulingResult, err := p.PredictedLatency.GetSchedulingResultForRequest(request)
103+
if err != nil {
104+
logger.V(logutil.DEBUG).Error(err, "Failed to get scheduling result for prefill training")
105+
return
106+
}
107+
108+
// P/D-specific: Extract prefill pod from the "prefill" profile
109+
prefillResult, exists := schedulingResult.ProfileResults["prefill"]
110+
if !exists || prefillResult == nil || len(prefillResult.TargetPods) == 0 {
111+
logger.V(logutil.DEBUG).Info("No prefill pod in scheduling result, skipping prefill training")
112+
return
113+
}
114+
115+
prefillPod := prefillResult.TargetPods[0]
116+
117+
// Get metrics for the prefill pod
118+
lastSeenMetrics, err := p.PredictedLatency.GetLastSeenMetricsForRequest(request)
119+
if err != nil {
120+
logger.V(logutil.DEBUG).Error(err, "Failed to get metrics for prefill training")
121+
return
122+
}
123+
124+
prefillMetrics, exists := lastSeenMetrics["prefill"]
125+
if !exists || prefillMetrics == nil {
126+
logger.V(logutil.DEBUG).Info("No metrics available for prefill pod")
127+
return
128+
}
129+
130+
// Get prefix cache score
131+
prefixCacheScores, err := p.PredictedLatency.GetPrefixCacheScoresForRequest(request)
132+
if err != nil {
133+
logger.V(logutil.DEBUG).Error(err, "Failed to get prefix cache scores")
134+
return
135+
}
136+
prefixCacheScore := prefixCacheScores[prefillPod.GetMetadata().String()]
137+
138+
// Get prompt
139+
prompt, err := p.PredictedLatency.GetRequestPrompt(request)
140+
if err != nil {
141+
logger.V(logutil.DEBUG).Error(err, "Failed to get prompt for prefill training")
142+
return
143+
}
144+
145+
// Build training entry using the PDPredictionRequestBuilder
146+
// This will automatically populate PodType="prefill" based on llm-d.ai/role label
147+
requestBuilder := p.PredictedLatency.GetRequestBuilder()
148+
entry := requestBuilder.BuildTrainingEntry(
149+
ctx,
150+
prefillPod,
151+
prefillMetrics,
152+
prompt,
153+
actualPrefillTTFT, // Actual TTFT from sidecar
154+
0, // TPOT not applicable for prefill
155+
time.Now(),
156+
0, // No tokens generated yet for prefill
157+
prefixCacheScore,
158+
)
159+
160+
// Record training data
161+
latencyPredictor := p.PredictedLatency.GetLatencyPredictor().(latencypredictor.PredictorInterface)
162+
if err := latencyPredictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
163+
logger.V(logutil.DEBUG).Error(err, "Failed to record prefill training data")
164+
} else {
165+
logger.V(logutil.DEBUG).Info("Recorded prefill training data",
166+
"pod", prefillPod.GetPod().String(),
167+
"ttft_ms", actualPrefillTTFT,
168+
"pod_type", "prefill")
169+
}
170+
}

0 commit comments

Comments
 (0)