@@ -21,7 +21,6 @@ import (
"context"
"errors"
"fmt"
- "strings"
"time"

"sigs.k8s.io/controller-runtime/pkg/log"
@@ -66,6 +65,7 @@ func getLatestMetricsForProfile(predictedLatencyCtx *predictedLatencyCtx) (*fwkd
func processPreRequestForLatencyPrediction(
ctx context.Context,
predictor latencypredictor.PredictorInterface,
+ endpointRoleLabel string,
predictedLatencyCtx *predictedLatencyCtx,
) error {
logger := log.FromContext(ctx)
@@ -80,17 +80,18 @@ func processPreRequestForLatencyPrediction(
return err
}

- targetPod := predictedLatencyCtx.targetMetadata
- prefix_cache_score := predictedLatencyCtx.prefixCacheScoresForEndpoints[targetPod.NamespacedName.Name]
+ target_endpoint_metadata := predictedLatencyCtx.targetMetadata
+ prefix_cache_score := predictedLatencyCtx.prefixCacheScoresForEndpoints[target_endpoint_metadata.NamespacedName.Name]

- in := latencypredictor.PredictionRequest{
- KVCachePercentage: m.KVCacheUsagePercent,
- InputTokenLength: len(strings.Fields(predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt)),
- NumRequestWaiting: m.WaitingQueueSize,
- NumRequestRunning: m.RunningRequestsSize,
- NumTokensGenerated: 0,
- PrefixCacheScore: prefix_cache_score,
- }
+ // Build prediction request (pod type is included if endpointRoleLabel is configured)
+ in := buildPredictionRequest(
+ endpointRoleLabel,
+ target_endpoint_metadata,
+ m,
+ predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
+ 0, // NumTokensGenerated is 0 for pre-request TTFT prediction
+ prefix_cache_score,
+ )

// Predict TTFT
start := time.Now()
@@ -120,6 +121,7 @@ func processFirstTokenForLatencyPrediction(
ctx context.Context,
predictor latencypredictor.PredictorInterface,
streamingMode bool,
+ endpointRoleLabel string,
predictedLatencyCtx *predictedLatencyCtx,
now time.Time,
samplingMean float64,
@@ -136,13 +138,13 @@
logger.V(logutil.DEBUG).Info("Skipping prediction due to missing metrics", "error", err)
return
}
- targetPod := predictedLatencyCtx.targetMetadata
- prefixCacheScore := predictedLatencyCtx.prefixCacheScoresForEndpoints[targetPod.NamespacedName.Name]
+ targetEndpointMetadata := predictedLatencyCtx.targetMetadata
+ prefixCacheScore := predictedLatencyCtx.prefixCacheScoresForEndpoints[targetEndpointMetadata.NamespacedName.Name]
logger.V(logutil.DEBUG).Info("Recording TTFT training data", "ttft_ms", predictedLatencyCtx.ttft, "prefixCacheScore", prefixCacheScore)
- recordTTFTTrainingData(ctx, predictor, predictedLatencyCtx, m, now, prefixCacheScore)
+ recordTTFTTrainingData(ctx, predictor, endpointRoleLabel, predictedLatencyCtx, m, targetEndpointMetadata, now, prefixCacheScore)

if streamingMode {
- predictFirstTPOT(ctx, predictor, predictedLatencyCtx)
+ predictFirstTPOT(ctx, predictor, endpointRoleLabel, predictedLatencyCtx, targetEndpointMetadata)
}

// Advance timestamp
@@ -163,24 +165,25 @@ func initializeSampler(ctx context.Context, predictedLatencyCtx *predictedLatenc
func recordTTFTTrainingData(
ctx context.Context,
predictor latencypredictor.PredictorInterface,
+ endpointRoleLabel string,
predictedLatencyCtx *predictedLatencyCtx,
m *fwkdl.Metrics,
+ targetEndpointMetadata *fwkdl.EndpointMetadata,
now time.Time,
prefixCacheScore float64,
) {
logger := log.FromContext(ctx)
- // Train TTFT
- entry := latencypredictor.TrainingEntry{
- KVCachePercentage: m.KVCacheUsagePercent,
- InputTokenLength: len(strings.Fields(predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt)),
- ActualTTFT: predictedLatencyCtx.ttft,
- ActualTPOT: 0,
- Timestamp: now,
- NumRequestWaiting: m.WaitingQueueSize,
- NumRequestRunning: m.RunningRequestsSize,
- NumTokensGenerated: 0,
- PrefixCacheScore: prefixCacheScore,
- }
+ entry := buildTrainingEntry(
+ endpointRoleLabel,
+ targetEndpointMetadata,
+ m,
+ predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
+ predictedLatencyCtx.ttft,
+ 0, // TTFT training
+ now,
+ 0,
+ prefixCacheScore,
+ )
if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
logger.V(logutil.DEBUG).Error(err, "record TTFT training failed")
}
@@ -189,7 +192,9 @@ func recordTTFTTrainingData(
func predictFirstTPOT(
ctx context.Context,
predictor latencypredictor.PredictorInterface,
+ endpointRoleLabel string,
predictedLatencyCtx *predictedLatencyCtx,
+ targetEndpointMetadata *fwkdl.EndpointMetadata,
) {
logger := log.FromContext(ctx)
m, err := getLatestMetricsForProfile(predictedLatencyCtx)
@@ -199,15 +204,14 @@ func predictFirstTPOT(
return
}

- // Predict first TPOT
- in := latencypredictor.PredictionRequest{
- KVCachePercentage: m.KVCacheUsagePercent,
- InputTokenLength: len(strings.Fields(predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt)),
- NumRequestWaiting: m.WaitingQueueSize,
- NumRequestRunning: m.RunningRequestsSize,
- NumTokensGenerated: predictedLatencyCtx.generatedTokenCount,
- PrefixCacheScore: 0,
- }
+ in := buildPredictionRequest(
+ endpointRoleLabel,
+ targetEndpointMetadata,
+ m,
+ predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
+ predictedLatencyCtx.generatedTokenCount,
+ 0, // TPOT does not use prefix cache score
+ )
start := time.Now()
p, err := predictor.Predict(ctx, in)
dur := time.Since(start)
@@ -227,7 +231,9 @@ func predictFirstTPOT(
func processTokenForLatencyPrediction(
ctx context.Context,
predictor latencypredictor.PredictorInterface,
+ endpointRoleLabel string,
predictedLatencyCtx *predictedLatencyCtx,
+ targetEndpointMetadata *fwkdl.EndpointMetadata,
now time.Time,
samplingMean float64,
maxSampledTokens int,
@@ -257,32 +263,31 @@ func processTokenForLatencyPrediction(
"error", err)
return
}
- // Record actual TPOT
- entry := latencypredictor.TrainingEntry{
- KVCachePercentage: m.KVCacheUsagePercent,
- InputTokenLength: len(strings.Fields(predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt)),
- ActualTTFT: 0,
- ActualTPOT: latencyMs,
- Timestamp: now,
- NumRequestWaiting: m.WaitingQueueSize,
- NumRequestRunning: m.RunningRequestsSize,
- NumTokensGenerated: predictedLatencyCtx.generatedTokenCount - 1,
- PrefixCacheScore: 0, // TPOT does not use prefix cache score
- }
+ entry := buildTrainingEntry(
+ endpointRoleLabel,
+ targetEndpointMetadata,
+ m,
+ predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
+ 0, // TTFT not recorded for TPOT
+ latencyMs,
+ now,
+ predictedLatencyCtx.generatedTokenCount-1,
+ 0, // TPOT does not use prefix cache score
+ )
if err := predictor.AddTrainingDataBulk([]latencypredictor.TrainingEntry{entry}); err != nil {
logger.V(logutil.DEBUG).Error(err, "record TPOT training failed")
}

// Sampled predict
if predictedLatencyCtx.tokenSampler.shouldPredict(predictedLatencyCtx.generatedTokenCount) {
- in := latencypredictor.PredictionRequest{
- KVCachePercentage: m.KVCacheUsagePercent,
- InputTokenLength: len(strings.Fields(predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt)),
- NumRequestWaiting: m.WaitingQueueSize,
- NumRequestRunning: m.RunningRequestsSize,
- NumTokensGenerated: predictedLatencyCtx.generatedTokenCount,
- PrefixCacheScore: 0, // TPOT does not use prefix cache score
- }
+ in := buildPredictionRequest(
+ endpointRoleLabel,
+ targetEndpointMetadata,
+ m,
+ predictedLatencyCtx.schedulingRequest.Body.Completions.Prompt,
+ predictedLatencyCtx.generatedTokenCount,
+ 0, // TPOT does not use prefix cache score
+ )
start := time.Now()
p, err := predictor.Predict(ctx, in)
dur := time.Since(start)
@@ -312,16 +317,18 @@ func bulkPredictWithMetrics(
ctx context.Context,
predictor latencypredictor.PredictorInterface,
metricsStates []*fwkdl.Metrics,
+ endpointRoleLabel string,
+ targetEndpointsMetadatas []*fwkdl.EndpointMetadata,
prompts []string,
generatedTokenCounts []int,
prefixCacheScores []float64,
) ([]*latencypredictor.PredictionResponse, error) {
logger := log.FromContext(ctx)

// Validate input lengths
- if len(metricsStates) != len(prompts) || len(prompts) != len(generatedTokenCounts) || len(generatedTokenCounts) != len(prefixCacheScores) {
- return nil, fmt.Errorf("input slice lengths must match: metrics=%d, prompts=%d, tokenCounts=%d, prefixScores=%d",
- len(metricsStates), len(prompts), len(generatedTokenCounts), len(prefixCacheScores))
+ if len(targetEndpointsMetadatas) != len(metricsStates) || len(metricsStates) != len(prompts) || len(prompts) != len(generatedTokenCounts) || len(generatedTokenCounts) != len(prefixCacheScores) {
+ return nil, fmt.Errorf("input slice lengths must match: endpoints=%d, metrics=%d, prompts=%d, tokenCounts=%d, prefixScores=%d",
+ len(targetEndpointsMetadatas), len(metricsStates), len(prompts), len(generatedTokenCounts), len(prefixCacheScores))
}

if len(metricsStates) == 0 {
@@ -335,17 +342,24 @@ }
}
}

+ // Validate that no endpoint metadata is nil
+ for i, endpointMetadata := range targetEndpointsMetadatas {
+ if endpointMetadata == nil {
+ return nil, fmt.Errorf("endpoint metadata at index %d cannot be nil", i)
+ }
+ }
+
// Build bulk prediction requests
bulkRequests := make([]latencypredictor.PredictionRequest, len(metricsStates))
for i := range metricsStates {
- bulkRequests[i] = latencypredictor.PredictionRequest{
- KVCachePercentage: metricsStates[i].KVCacheUsagePercent,
- InputTokenLength: len(strings.Fields(prompts[i])),
- NumRequestWaiting: metricsStates[i].WaitingQueueSize,
- NumRequestRunning: metricsStates[i].RunningRequestsSize,
- NumTokensGenerated: generatedTokenCounts[i],
- PrefixCacheScore: prefixCacheScores[i],
- }
+ bulkRequests[i] = buildPredictionRequest(
+ endpointRoleLabel,
+ targetEndpointsMetadatas[i],
+ metricsStates[i],
+ prompts[i],
+ generatedTokenCounts[i],
+ prefixCacheScores[i],
+ )
}

// Perform bulk prediction
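Note on the helpers above: buildPredictionRequest and buildTrainingEntry are introduced elsewhere in this PR, so their definitions are not part of the visible diff. The sketch below is inferred purely from the call sites: the argument order matches the calls above, but the field mapping and the role-label handling (the PodType field and the label lookup) are assumptions, not the PR's confirmed API. It would also explain why the strings import is dropped at the top of the file, since the strings.Fields token counting presumably moves into the shared helper.

// Sketch only, not the PR's actual code. Assumes imports: strings, time,
// plus the fwkdl and latencypredictor packages used above. Also assumes a
// hypothetical PodType field on the request/entry types and that
// EndpointMetadata exposes the pod's labels.
func buildPredictionRequest(
	endpointRoleLabel string,
	endpoint *fwkdl.EndpointMetadata,
	m *fwkdl.Metrics,
	prompt string,
	numTokensGenerated int,
	prefixCacheScore float64,
) latencypredictor.PredictionRequest {
	req := latencypredictor.PredictionRequest{
		KVCachePercentage:  m.KVCacheUsagePercent,
		InputTokenLength:   len(strings.Fields(prompt)),
		NumRequestWaiting:  m.WaitingQueueSize,
		NumRequestRunning:  m.RunningRequestsSize,
		NumTokensGenerated: numTokensGenerated,
		PrefixCacheScore:   prefixCacheScore,
	}
	if endpointRoleLabel != "" {
		req.PodType = endpoint.Labels[endpointRoleLabel] // hypothetical field and lookup
	}
	return req
}

func buildTrainingEntry(
	endpointRoleLabel string,
	endpoint *fwkdl.EndpointMetadata,
	m *fwkdl.Metrics,
	prompt string,
	actualTTFT float64,
	actualTPOT float64,
	timestamp time.Time,
	numTokensGenerated int,
	prefixCacheScore float64,
) latencypredictor.TrainingEntry {
	entry := latencypredictor.TrainingEntry{
		KVCachePercentage:  m.KVCacheUsagePercent,
		InputTokenLength:   len(strings.Fields(prompt)),
		ActualTTFT:         actualTTFT,
		ActualTPOT:         actualTPOT,
		Timestamp:          timestamp,
		NumRequestWaiting:  m.WaitingQueueSize,
		NumRequestRunning:  m.RunningRequestsSize,
		NumTokensGenerated: numTokensGenerated,
		PrefixCacheScore:   prefixCacheScore,
	}
	if endpointRoleLabel != "" {
		entry.PodType = endpoint.Labels[endpointRoleLabel] // hypothetical field and lookup
	}
	return entry
}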
@@ -23,6 +23,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
+ "k8s.io/apimachinery/pkg/types"
fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
)
@@ -39,11 +40,19 @@ func TestBulkPredictWithMetrics(t *testing.T) {
{KVCacheUsagePercent: 0.5},
{KVCacheUsagePercent: 0.6},
}
+ pods := []*fwkdl.EndpointMetadata{
+ {
+ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+ },
+ {
+ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod2"},
+ },
+ }
prompts := []string{"prompt1", "prompt2"}
generatedTokenCounts := []int{1, 1}
prefixCacheScores := []float64{0.0, 0.0}

- results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
+ results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, "", pods, prompts, generatedTokenCounts, prefixCacheScores)

assert.NoError(t, err)
assert.Len(t, results, 2)
@@ -61,11 +70,16 @@ func TestBulkPredictWithMetrics_Error(t *testing.T) {
metricsStates := []*fwkdl.Metrics{
{KVCacheUsagePercent: 0.5},
}
+ pods := []*fwkdl.EndpointMetadata{
+ {
+ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+ },
+ }
prompts := []string{"prompt1"}
generatedTokenCounts := []int{1}
prefixCacheScores := []float64{0.0}

- results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
+ results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, "", pods, prompts, generatedTokenCounts, prefixCacheScores)

assert.Error(t, err)
assert.Nil(t, results)
@@ -74,11 +88,16 @@ func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) {
func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) {
mockPredictor := &mockPredictor{}
metricsStates := []*fwkdl.Metrics{{}}
+ pods := []*fwkdl.EndpointMetadata{
+ {
+ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+ },
+ }
prompts := []string{"prompt1", "prompt2"} // Mismatch length
generatedTokenCounts := []int{1}
prefixCacheScores := []float64{0.0}

- results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
+ results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, "", pods, prompts, generatedTokenCounts, prefixCacheScores)

assert.Error(t, err)
assert.Nil(t, results)
@@ -88,11 +107,16 @@ func TestBulkPredictWithMetrics_InputMismatch(t *testing.T) {
func TestBulkPredictWithMetrics_NilMetricsState(t *testing.T) {
mockPredictor := &mockPredictor{}
metricsStates := []*fwkdl.Metrics{nil} // Nil metrics state
+ pods := []*fwkdl.EndpointMetadata{
+ {
+ NamespacedName: types.NamespacedName{Namespace: "default", Name: "pod1"},
+ },
+ }
prompts := []string{"prompt1"}
generatedTokenCounts := []int{1}
prefixCacheScores := []float64{0.0}

- results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
+ results, err := bulkPredictWithMetrics(context.Background(), mockPredictor, metricsStates, "", pods, prompts, generatedTokenCounts, prefixCacheScores)

assert.Error(t, err)
assert.Nil(t, results)
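The tests above pass "" as the endpointRoleLabel argument, which (per the comment in the first file) disables pod-type tagging. For contrast, here is a sketch of a caller with a configured label; the label key is invented for illustration, and the real caller is the scorer in the next file.

// Sketch only: same call shape as the tests, but with a non-empty role label.
// "example.com/endpoint-role" is a made-up label key.
func bulkPredictWithRoleLabel(
	ctx context.Context,
	predictor latencypredictor.PredictorInterface,
	metricsStates []*fwkdl.Metrics,
	pods []*fwkdl.EndpointMetadata,
	prompts []string,
	tokenCounts []int,
	prefixScores []float64,
) ([]*latencypredictor.PredictionResponse, error) {
	return bulkPredictWithMetrics(ctx, predictor, metricsStates,
		"example.com/endpoint-role", // "" would disable pod-type tagging
		pods, prompts, tokenCounts, prefixScores)
}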
@@ -48,6 +48,7 @@ func (s *PredictedLatency) generatePredictions(ctx context.Context, request *sch

// Prepare inputs for bulk prediction
metricsStates := make([]*fwkdl.Metrics, len(candidateEndpoints))
+ targetEndpointsMetadatas := make([]*fwkdl.EndpointMetadata, len(candidateEndpoints))
prompts := make([]string, len(candidateEndpoints))
generatedTokenCounts := make([]int, len(candidateEndpoints))
prefixCacheScores := make([]float64, len(candidateEndpoints))
@@ -61,13 +62,14 @@
logger.V(logutil.DEBUG).Info("Prefix cache score for pod", "pod", endpoint.GetMetadata().String(), "prefixCacheScore", prefixCacheScore)

metricsStates[i] = endpoint.GetMetrics()
+ targetEndpointsMetadatas[i] = endpoint.GetMetadata()
prompts[i] = request.Body.Completions.Prompt
generatedTokenCounts[i] = 1
prefixCacheScores[i] = prefixCacheScore
}

// Bulk predict
- bulkPredictions, err := bulkPredictWithMetrics(ctx, s.latencypredictor, metricsStates, prompts, generatedTokenCounts, prefixCacheScores)
+ bulkPredictions, err := bulkPredictWithMetrics(ctx, s.latencypredictor, metricsStates, s.config.EndpointRoleLabel, targetEndpointsMetadatas, prompts, generatedTokenCounts, prefixCacheScores)
if err != nil {
logger.V(logutil.DEBUG).Error(err, "Bulk prediction failed")
return nil, err
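The scorer threads s.config.EndpointRoleLabel through to bulkPredictWithMetrics. The config field itself is not in the visible diff; a plausible shape, with placement and wording guessed, follows.

// Assumed config shape; only the EndpointRoleLabel field is implied by the
// diff above, and its documentation is a guess.
type Config struct {
	// EndpointRoleLabel is the pod label key identifying an endpoint's role
	// (e.g. prefill vs decode). Empty disables role tagging in prediction
	// and training requests.
	EndpointRoleLabel string
	// ... other PredictedLatency scorer options not shown in this diff
}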