Skip to content

Commit 423ee58

Browse files
committed
Add the EndpointRoleLabel as a parameter for the predicted-latency-scorer
Get rid of the RequestBuilderStruct pattern and use helper funcs instead
1 parent 3be0cb9 commit 423ee58

File tree

7 files changed

+56
-95
lines changed

7 files changed

+56
-95
lines changed

pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/latencypredictor_helper.go

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func getLatestMetricsForProfile(predictedLatencyCtx *predictedLatencyCtx) (*fwkd
6565
func processPreRequestForLatencyPrediction(
6666
ctx context.Context,
6767
predictor latencypredictor.PredictorInterface,
68-
requestBuilder PredictionRequestBuilder,
68+
endpointRoleLabel string,
6969
predictedLatencyCtx *predictedLatencyCtx,
7070
) error {
7171
logger := log.FromContext(ctx)
@@ -83,9 +83,9 @@ func processPreRequestForLatencyPrediction(
8383
target_endpoint_metadata := predictedLatencyCtx.targetMetadata
8484
prefix_cache_score := predictedLatencyCtx.prefixCacheScoresForEndpoints[target_endpoint_metadata.NamespacedName.Name]
8585

86-
// Build prediction request using the builder (ensures pod type is included for P/D)
87-
in := requestBuilder.BuildPredictionRequest(
88-
ctx,
86+
// Build prediction request (pod type is included if endpointRoleLabel is configured)
87+
in := buildPredictionRequest(
88+
endpointRoleLabel,
8989
target_endpoint_metadata,
9090
m,
9191
predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
@@ -121,7 +121,7 @@ func processFirstTokenForLatencyPrediction(
121121
ctx context.Context,
122122
predictor latencypredictor.PredictorInterface,
123123
streamingMode bool,
124-
requestBuilder PredictionRequestBuilder,
124+
endpointRoleLabel string,
125125
predictedLatencyCtx *predictedLatencyCtx,
126126
now time.Time,
127127
samplingMean float64,
@@ -141,10 +141,10 @@ func processFirstTokenForLatencyPrediction(
141141
targetEndpointMetadata := predictedLatencyCtx.targetMetadata
142142
prefixCacheScore := predictedLatencyCtx.prefixCacheScoresForEndpoints[targetEndpointMetadata.NamespacedName.Name]
143143
logger.V(logutil.DEBUG).Info("Recording TTFT training data", "ttft_ms", predictedLatencyCtx.ttft, "prefixCacheScore", prefixCacheScore)
144-
recordTTFTTrainingData(ctx, predictor, requestBuilder, predictedLatencyCtx, m, targetEndpointMetadata, now, prefixCacheScore)
144+
recordTTFTTrainingData(ctx, predictor, endpointRoleLabel, predictedLatencyCtx, m, targetEndpointMetadata, now, prefixCacheScore)
145145

146146
if streamingMode {
147-
predictFirstTPOT(ctx, predictor, requestBuilder, predictedLatencyCtx, targetEndpointMetadata)
147+
predictFirstTPOT(ctx, predictor, endpointRoleLabel, predictedLatencyCtx, targetEndpointMetadata)
148148
}
149149

150150
// Advance timestamp
@@ -165,17 +165,16 @@ func initializeSampler(ctx context.Context, predictedLatencyCtx *predictedLatenc
165165
func recordTTFTTrainingData(
166166
ctx context.Context,
167167
predictor latencypredictor.PredictorInterface,
168-
requestBuilder PredictionRequestBuilder,
168+
endpointRoleLabel string,
169169
predictedLatencyCtx *predictedLatencyCtx,
170170
m *fwkdl.Metrics,
171171
targetEndpointMetadata *fwkdl.EndpointMetadata,
172172
now time.Time,
173173
prefixCacheScore float64,
174174
) {
175175
logger := log.FromContext(ctx)
176-
// Build training entry using the builder
177-
entry := requestBuilder.BuildTrainingEntry(
178-
ctx,
176+
entry := buildTrainingEntry(
177+
endpointRoleLabel,
179178
targetEndpointMetadata,
180179
m,
181180
predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
@@ -193,7 +192,7 @@ func recordTTFTTrainingData(
193192
func predictFirstTPOT(
194193
ctx context.Context,
195194
predictor latencypredictor.PredictorInterface,
196-
requestBuilder PredictionRequestBuilder,
195+
endpointRoleLabel string,
197196
predictedLatencyCtx *predictedLatencyCtx,
198197
targetEndpointMetadata *fwkdl.EndpointMetadata,
199198
) {
@@ -205,9 +204,8 @@ func predictFirstTPOT(
205204
return
206205
}
207206

208-
// Build prediction request using the builder (ensures pod type is included for P/D)
209-
in := requestBuilder.BuildPredictionRequest(
210-
ctx,
207+
in := buildPredictionRequest(
208+
endpointRoleLabel,
211209
targetEndpointMetadata,
212210
m,
213211
predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
@@ -233,7 +231,7 @@ func predictFirstTPOT(
233231
func processTokenForLatencyPrediction(
234232
ctx context.Context,
235233
predictor latencypredictor.PredictorInterface,
236-
requestBuilder PredictionRequestBuilder,
234+
endpointRoleLabel string,
237235
predictedLatencyCtx *predictedLatencyCtx,
238236
targetEndpointMetadata *fwkdl.EndpointMetadata,
239237
now time.Time,
@@ -265,9 +263,8 @@ func processTokenForLatencyPrediction(
265263
"error", err)
266264
return
267265
}
268-
// Record actual TPOT using builder
269-
entry := requestBuilder.BuildTrainingEntry(
270-
ctx,
266+
entry := buildTrainingEntry(
267+
endpointRoleLabel,
271268
targetEndpointMetadata,
272269
m,
273270
predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
@@ -283,9 +280,8 @@ func processTokenForLatencyPrediction(
283280

284281
// Sampled predict
285282
if predictedLatencyCtx.tokenSampler.shouldPredict(predictedLatencyCtx.generatedTokenCount) {
286-
// Build prediction request using the builder (ensures pod type is included for P/D)
287-
in := requestBuilder.BuildPredictionRequest(
288-
ctx,
283+
in := buildPredictionRequest(
284+
endpointRoleLabel,
289285
targetEndpointMetadata,
290286
m,
291287
predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
@@ -321,7 +317,7 @@ func bulkPredictWithMetrics(
321317
ctx context.Context,
322318
predictor latencypredictor.PredictorInterface,
323319
metricsStates []*fwkdl.Metrics,
324-
requestBuilder PredictionRequestBuilder,
320+
endpointRoleLabel string,
325321
targetEndpointsMetadatas []*fwkdl.EndpointMetadata,
326322
prompts []string,
327323
generatedTokenCounts []int,
@@ -353,11 +349,11 @@ func bulkPredictWithMetrics(
353349
}
354350
}
355351

356-
// Build bulk prediction requests using the builder
352+
// Build bulk prediction requests
357353
bulkRequests := make([]latencypredictor.PredictionRequest, len(metricsStates))
358354
for i := range metricsStates {
359-
bulkRequests[i] = requestBuilder.BuildPredictionRequest(
360-
ctx,
355+
bulkRequests[i] = buildPredictionRequest(
356+
endpointRoleLabel,
361357
targetEndpointsMetadatas[i],
362358
metricsStates[i],
363359
prompts[i],

pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/latencypredictor_helper_test.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ func TestBulkPredictWithMetrics(t *testing.T) {
4040
{KVCacheUsagePercent: 0.5},
4141
{KVCacheUsagePercent: 0.6},
4242
}
43-
requestBuilder := &DefaultPredictionRequestBuilder{}
4443
pods := []*fwkdl.EndpointMetadata{
4544
{
4645
NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
@@ -53,7 +52,7 @@ func TestBulkPredictWithMetrics(t *testing.T) {
5352
generatedTokenCounts := []int{1, 1}
5453
prefixCacheScores := []float64{0.0, 0.0}
5554

56-
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)
55+
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, "", pods, prompts, generatedTokenCounts, prefixCacheScores)
5756

5857
assert.NoError(t, err)
5958
assert.Len(t, results, 2)
@@ -71,7 +70,6 @@ func TestBulkPredictWithMetrics_Error(t *testing.T) {
7170
metricsStates := []*fwkdl.Metrics{
7271
{KVCacheUsagePercent: 0.5},
7372
}
74-
requestBuilder := &DefaultPredictionRequestBuilder{}
7573
pods := []*fwkdl.EndpointMetadata{
7674
{
7775
NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
@@ -81,7 +79,7 @@ func TestBulkPredictWithMetrics_Error(t *testing.T) {
8179
generatedTokenCounts := []int{1}
8280
prefixCacheScores := []float64{0.0}
8381

84-
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)
82+
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, "", pods, prompts, generatedTokenCounts, prefixCacheScores)
8583

8684
assert.Error(t, err)
8785
assert.Nil(t, results)
@@ -90,7 +88,6 @@ func TestBulkPredictWithMetrics_Error(t *testing.T) {
9088
func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) {
9189
mockPredictor := &mockPredictor{}
9290
metricsStates := []*fwkdl.Metrics{{}}
93-
requestBuilder := &DefaultPredictionRequestBuilder{}
9491
pods := []*fwkdl.EndpointMetadata{
9592
{
9693
NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
@@ -100,7 +97,7 @@ func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) {
10097
generatedTokenCounts := []int{1}
10198
prefixCacheScores := []float64{0.0}
10299

103-
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)
100+
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, "", pods, prompts, generatedTokenCounts, prefixCacheScores)
104101

105102
assert.Error(t, err)
106103
assert.Nil(t, results)
@@ -110,7 +107,6 @@ func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) {
110107
func TestBulkPredictWithMetrics_NilMetricsState(t *testing.T) {
111108
mockPredictor := &mockPredictor{}
112109
metricsStates := []*fwkdl.Metrics{nil} // Nil metrics state
113-
requestBuilder := &DefaultPredictionRequestBuilder{}
114110
pods := []*fwkdl.EndpointMetadata{
115111
{
116112
NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
@@ -120,7 +116,7 @@ func TestBulkPredictWithMetrics_NilMetricsState(t *testing.T) {
120116
generatedTokenCounts := []int{1}
121117
prefixCacheScores := []float64{0.0}
122118

123-
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, requestBuilder, pods, prompts, generatedTokenCounts, prefixCacheScores)
119+
results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, "", pods, prompts, generatedTokenCounts, prefixCacheScores)
124120

125121
assert.Error(t, err)
126122
assert.Nil(t, results)

pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/prediction.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func (s *PredictedLatency) generatePredictions(ctx context.Context, request *sch
6969
}
7070

7171
// Bulk predict
72-
bulkPredictions, err := bulkPredictWithMetrics(ctx, s.latencypredictor, metricsStates, s.requestBuilder, targetEndpointsMetadatas, prompts, generatedTokenCounts, prefixCacheScores)
72+
bulkPredictions, err := bulkPredictWithMetrics(ctx, s.latencypredictor, metricsStates, s.config.EndpointRoleLabel, targetEndpointsMetadatas, prompts, generatedTokenCounts, prefixCacheScores)
7373
if err != nil {
7474
logger.V(logutil.DEBUG).Error(err, "Bulk prediction failed")
7575
return nil, err

pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/requestcontrol_hooks.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ func (t *PredictedLatency) PreRequest(ctx context.Context, request *schedulingty
187187
refreshLastSeenMetrics(ctx, predictedLatencyCtx)
188188
t.setPredictedLatencyContextForRequest(request, predictedLatencyCtx)
189189

190-
if err := processPreRequestForLatencyPrediction(ctx, t.latencypredictor, t.requestBuilder, predictedLatencyCtx); err != nil {
190+
if err := processPreRequestForLatencyPrediction(ctx, t.latencypredictor, t.config.EndpointRoleLabel, predictedLatencyCtx); err != nil {
191191
logger.V(logutil.DEBUG).Error(err, "Process PreRequest in latencypredictor failed")
192192
}
193193
}
@@ -233,8 +233,8 @@ func (t *PredictedLatency) ResponseReceived(ctx context.Context, request *schedu
233233
"prefillPod", prefillMetadata.NamespacedName.Name,
234234
"prefixCacheScore", prefixCacheScore)
235235

236-
entry := t.requestBuilder.BuildTrainingEntry(
237-
ctx,
236+
entry := buildTrainingEntry(
237+
t.config.EndpointRoleLabel,
238238
prefillMetadata,
239239
prefillMetrics,
240240
predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
@@ -272,9 +272,9 @@ func (t *PredictedLatency) ResponseStreaming(ctx context.Context, request *sched
272272
}
273273

274274
if predictedLatencyCtx.ttft == 0 {
275-
processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.StreamingMode, t.requestBuilder, predictedLatencyCtx, now, t.config.SamplingMean, t.config.MaxSampledTokens)
275+
processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.StreamingMode, t.config.EndpointRoleLabel, predictedLatencyCtx, now, t.config.SamplingMean, t.config.MaxSampledTokens)
276276
} else {
277-
processTokenForLatencyPrediction(ctx, t.latencypredictor, t.requestBuilder, predictedLatencyCtx, targetMetadata, now, t.config.SamplingMean, t.config.MaxSampledTokens)
277+
processTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.EndpointRoleLabel, predictedLatencyCtx, targetMetadata, now, t.config.SamplingMean, t.config.MaxSampledTokens)
278278
}
279279

280280
}
@@ -298,7 +298,7 @@ func (t *PredictedLatency) ResponseComplete(ctx context.Context, request *schedu
298298
}
299299
now := time.Now()
300300
if !t.config.StreamingMode {
301-
processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.StreamingMode, t.requestBuilder, predictedLatencyCtx, now, t.config.SamplingMean, t.config.MaxSampledTokens)
301+
processFirstTokenForLatencyPrediction(ctx, t.latencypredictor, t.config.StreamingMode, t.config.EndpointRoleLabel, predictedLatencyCtx, now, t.config.SamplingMean, t.config.MaxSampledTokens)
302302
}
303303

304304
if predictedLatencyCtx.ttft > 0 {

pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/requestcontrol_hooks_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ func createTestRouter() *PredictedLatency {
6464
),
6565
// runningRequestLists is a sync.Map and needs no initialization
6666
latencypredictor: nil,
67-
requestBuilder: &DefaultPredictionRequestBuilder{},
6867
config: DefaultConfig,
6968
}
7069
}

pkg/epp/framework/plugins/scheduling/scorer/predictedlatency/scorer.go

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ import (
4040
type PredictedLatency struct {
4141
typedName plugin.TypedName
4242
latencypredictor latencypredictor.PredictorInterface
43-
requestBuilder PredictionRequestBuilder
4443
runningRequestLists sync.Map // Key: types.NamespacedName, Value: *requestPriorityQueue
4544
sloContextStore *ttlcache.Cache[string, *predictedLatencyCtx] // TTL cache for request contexts
4645
headroomStrategy headroomStrategy
@@ -68,6 +67,7 @@ type Config struct {
6867
ContextTTL time.Duration `json:"contextTTL,omitempty"`
6968
SelectionMode string `json:"selectionMode,omitempty"`
7069
StreamingMode bool `json:"streamingMode,omitempty"`
70+
EndpointRoleLabel string `json:"endpointRoleLabel,omitempty"`
7171
}
7272

7373
var DefaultConfig = Config{
@@ -164,7 +164,6 @@ func NewPredictedLatency(config Config, predictor latencypredictor.PredictorInte
164164
predictedLatency := &PredictedLatency{
165165
typedName: plugin.TypedName{Type: PredictedLatencyPluginType, Name: PredictedLatencyPluginType},
166166
latencypredictor: predictor,
167-
requestBuilder: &DefaultPredictionRequestBuilder{}, // Default, can be customized via SetRequestBuilder
168167
// runningRequestLists is a sync.Map and needs no initialization
169168
headroomStrategy: strategy,
170169
config: config,
@@ -215,16 +214,6 @@ func (s *PredictedLatency) WithName(name string) *PredictedLatency {
215214
return s
216215
}
217216

218-
// SetRequestBuilder sets a custom prediction request builder.
219-
// This allows external packages (e.g., llm-d-inference-scheduler) to customize
220-
// how prediction and training requests are constructed, for example to add
221-
// pod type information for disaggregated serving scenarios.
222-
func (s *PredictedLatency) SetRequestBuilder(builder PredictionRequestBuilder) {
223-
if builder != nil {
224-
s.requestBuilder = builder
225-
}
226-
}
227-
228217
func (s *PredictedLatency) epsilonGreedyAffinityGate(
229218
ctx context.Context,
230219
candidates []endpointPredictionResult,

0 commit comments

Comments (0)