@@ -24,6 +24,7 @@ import (
2424 "math/rand"
2525 "net"
2626 "strings"
27+ "sync"
2728 "time"
2829
2930 "sigs.k8s.io/controller-runtime/pkg/log"
@@ -42,6 +43,11 @@ import (
4243 requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
4344)
4445
46+ const (
47+ prepareDataTimeout = 200 * time .Millisecond
48+ prepareDataMaxRetries = 3
49+ )
50+
4551// Datastore defines the interface required by the Director.
4652type Datastore interface {
4753 PoolGet () (* v1.InferencePool , error )
@@ -108,19 +114,19 @@ func (d *Director) getInferenceObjective(ctx context.Context, reqCtx *handlers.R
108114}
109115
110116// resolveTargetModel is a helper to update reqCtx with target model based on request.
111- func (d * Director ) resolveTargetModel (reqCtx * handlers.RequestContext ) error {
117+ func (d * Director ) resolveTargetModel (reqCtx * handlers.RequestContext ) ( * handlers. RequestContext , error ) {
112118 requestBodyMap := reqCtx .Request .Body
113119 var ok bool
114120 reqCtx .IncomingModelName , ok = requestBodyMap ["model" ].(string )
115121 if ! ok {
116- return errutil.Error {Code : errutil .BadRequest , Msg : "model not found in request body" }
122+ return nil , errutil.Error {Code : errutil .BadRequest , Msg : "model not found in request body" }
117123 }
118124 if reqCtx .TargetModelName == "" {
119125 // Default to incoming model name
120126 reqCtx .TargetModelName = reqCtx .IncomingModelName
121127 }
122128 reqCtx .Request .Body ["model" ] = reqCtx .TargetModelName
123- return nil
129+ return reqCtx , nil
124130}
125131
126132// HandleRequest orchestrates the request lifecycle.
@@ -129,7 +135,7 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
129135 logger := log .FromContext (ctx )
130136
131137 // Resolve target model and update req context.
132- err := d .resolveTargetModel (reqCtx )
138+ reqCtx , err := d .resolveTargetModel (reqCtx )
133139 if err != nil {
134140 return reqCtx , err
135141 }
@@ -161,15 +167,13 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo
161167 if len (candidatePods ) == 0 {
162168 return reqCtx , errutil.Error {Code : errutil .ServiceUnavailable , Msg : "failed to find candidate pods for serving the request" }
163169 }
164- // TODO(rahulgurnani/lukevandrie): Perhaps, refactor/implement Admit plugin for Admission control.
165170 if err := d .admissionController .Admit (ctx , reqCtx , candidatePods , * infObjective .Spec .Priority ); err != nil {
166171 logger .V (logutil .DEFAULT ).Info ("Request rejected by admission control" , "error" , err )
167172 return reqCtx , err
168173 }
169174 snapshotOfCandidatePods := d .toSchedulerPodMetrics (candidatePods )
170175
171- // Prepare per request data
172- // TODO(rahulgurnani): Add retries and timeout in the preparedata step.
176+ // Prepare per request data by running PrepareData plugins.
173177 d .runPrepareDataPlugins (ctx , reqCtx .SchedulingRequest , snapshotOfCandidatePods )
174178
175179 // Run admit request plugins
@@ -343,14 +347,45 @@ func (d *Director) runPreRequestPlugins(ctx context.Context, request *scheduling
343347 }
344348}
345349
350+ // prepareData runs the PrepareData plugin with retries and timeout.
351+ func prepareData (plugin PrepareData , ctx context.Context , request * schedulingtypes.LLMRequest , pods []types.Pod ) {
352+ currentTimeout := prepareDataTimeout
353+ for i := 0 ; i <= prepareDataMaxRetries ; i ++ {
354+ done := make (chan struct {})
355+ go func () {
356+ defer close (done )
357+ plugin .PrepareData (ctx , request , pods )
358+ }()
359+
360+ select {
361+ case <- done :
362+ // Plugin executed successfully
363+ return
364+ case <- time .After (currentTimeout ):
365+ log .FromContext (ctx ).V (logutil .DEBUG ).Info ("PrepareData plugin timed out, retrying..." , "plugin" , plugin .TypedName (), "retry" , i + 1 , "timeout" , currentTimeout )
366+ if i == prepareDataMaxRetries {
367+ log .FromContext (ctx ).Error (nil , "PrepareData plugin failed after multiple retries" , "plugin" , plugin .TypedName ())
368+ return
369+ }
370+ }
371+ }
372+ }
373+
346374func (d * Director ) runPrepareDataPlugins (ctx context.Context ,
347375 request * schedulingtypes.LLMRequest , pods []types.Pod ) {
348376 loggerDebug := log .FromContext (ctx ).V (logutil .DEBUG )
377+ // Parallely execute PrepareData for all the plugins. Some plugins might take time to prepare data e.g. latency predictor.
378+ // Failure in any prepareData doesn't block the request processing.
379+ var wg sync.WaitGroup
349380 for _ , plugin := range d .requestControlPlugins .prepareDataPlugins {
350381 loggerDebug .Info ("Running PrepareData plugin" , "plugin" , plugin .TypedName ())
351- plugin .PrepareData (ctx , request , pods )
352- loggerDebug .Info ("Completed running PrepareData plugin successfully" , "plugin" , plugin .TypedName ())
382+ wg .Add (1 )
383+ go func (p PrepareData ) {
384+ defer wg .Done ()
385+ prepareData (p , ctx , request , pods )
386+ }(plugin )
353387 }
388+ wg .Wait ()
354389}
355390
356391func (d * Director ) runAdmitRequestPlugins (ctx context.Context ,
0 commit comments