@@ -262,6 +262,16 @@ int ModelSession::GetServingSessionId() {
262
262
}
263
263
264
264
// Runs a prediction on whichever session the current serving policy
// selects (GetServingSessionId); see the sess_id overload to target a
// specific session explicitly.
Status ModelSession::Predict(Request& req, Response& resp) {
  return InternalPredict(req, resp, GetServingSessionId());
}
267
+
268
+ Status ModelSession::Predict (Request& req, Response& resp,
269
+ int sess_id) {
270
+ return InternalPredict (req, resp, sess_id);
271
+ }
272
+
273
+ Status ModelSession::InternalPredict (Request& req, Response& resp,
274
+ int sess_id) {
265
275
if (is_local_) {
266
276
return Status (error::Code::INTERNAL,
267
277
" Local sparse storage, please use LocalPredict." );
@@ -278,17 +288,31 @@ Status ModelSession::Predict(Request& req, Response& resp) {
278
288
// TODO: which session selected to run on, add some policy here
279
289
status = session_group_->Run (run_options, req.inputs ,
280
290
req.output_tensor_names , {}, &resp.outputs ,
281
- &run_metadata, GetServingSessionId () );
291
+ &run_metadata, sess_id );
282
292
Tracer::GetTracer ()->GenTimeline (run_metadata);
283
293
} else {
284
294
status = session_group_->Run (req.inputs , req.output_tensor_names ,
285
- {}, &resp.outputs , GetServingSessionId () );
295
+ {}, &resp.outputs , sess_id );
286
296
}
287
297
--counter_;
288
298
return status;
289
299
}
290
300
291
- Status ModelSession::LocalPredict (Request& req, Response& resp) {
301
+ Status ModelSession::LocalPredict (Request& req,
302
+ Response& resp) {
303
+ return InternalLocalPredict (req, resp,
304
+ GetServingSessionId ());
305
+ }
306
+
307
+ Status ModelSession::LocalPredict (Request& req,
308
+ Response& resp,
309
+ int sess_id) {
310
+ return InternalLocalPredict (req, resp, sess_id);
311
+ }
312
+
313
+ Status ModelSession::InternalLocalPredict (Request& req,
314
+ Response& resp,
315
+ int sess_id) {
292
316
if (!is_local_) {
293
317
return Status (error::Code::INTERNAL,
294
318
" Remote sparse storage, please use Predict." );
@@ -302,16 +326,31 @@ Status ModelSession::LocalPredict(Request& req, Response& resp) {
302
326
// TODO: which session selected to run on, add some policy here
303
327
status = session_group_->Run (run_options, req.inputs ,
304
328
req.output_tensor_names , {}, &resp.outputs ,
305
- &run_metadata, GetServingSessionId () );
329
+ &run_metadata, sess_id );
306
330
Tracer::GetTracer ()->GenTimeline (run_metadata);
307
331
} else {
308
332
status = session_group_->Run (req.inputs , req.output_tensor_names ,
309
- {}, &resp.outputs , GetServingSessionId () );
333
+ {}, &resp.outputs , sess_id );
310
334
}
311
335
--counter_;
312
336
return status;
313
337
}
314
338
339
+ Status ModelSession::Warmup (Request& req, Response& resp, bool local) {
340
+ int N = session_group_->GetSessionNum ();
341
+ for (int i = 0 ; i < N; ++i) {
342
+ Status s;
343
+ if (local) {
344
+ s = LocalPredict (req, resp, i);
345
+ } else {
346
+ s = Predict (req, resp, i);
347
+ }
348
+ if (!s.ok ()) return s;
349
+ }
350
+
351
+ return Status::OK ();
352
+ }
353
+
315
354
// Forwards the prediction to the currently serving model session.
Status ModelSessionMgr::Predict(Request& req, Response& resp) {
  return serving_session_->Predict(req, resp);
}
@@ -320,6 +359,10 @@ Status ModelSessionMgr::LocalPredict(Request& req, Response& resp) {
320
359
return serving_session_->LocalPredict (req, resp);
321
360
}
322
361
362
+ Status ModelSessionMgr::Warmup (Request& req, Response& resp, bool local) {
363
+ return serving_session_->Warmup (req, resp, local);
364
+ }
365
+
323
366
Status ModelSessionMgr::CreateModelSession (
324
367
const Version& version, const char * ckpt_name,
325
368
IFeatureStoreMgr* sparse_storage, bool is_incr_ckpt,
0 commit comments