Skip to content

Commit 33a6f53

Browse files
shanshanptliutongxuan
authored and committed
[Serving] Fix warmup failed bug when use session_group. (#520)
1 parent a7c37f9 commit 33a6f53

File tree

3 files changed

+60
-11
lines changed

3 files changed

+60
-11
lines changed

Diff for: serving/processor/serving/model_instance.cc

+6-6
Original file line number | Diff line number | Diff line change
@@ -364,10 +364,10 @@ Status LocalSessionInstance::Warmup(
364364
int left_try_count = WARMUP_COUNT;
365365
while (left_try_count > 0) {
366366
if (warmup_session) {
367-
s = warmup_session->LocalPredict(
367+
s = warmup_session->Warmup(
368368
call.request, call.response);
369369
} else {
370-
s = session_mgr_->LocalPredict(
370+
s = session_mgr_->Warmup(
371371
call.request, call.response);
372372
}
373373
if (!s.ok()) return s;
@@ -563,11 +563,11 @@ Status RemoteSessionInstance::Warmup(
563563
int left_try_count = WARMUP_COUNT;
564564
while (left_try_count > 0) {
565565
if (warmup_session) {
566-
s = warmup_session->LocalPredict(
567-
call.request, call.response);
566+
s = warmup_session->Warmup(
567+
call.request, call.response, false);
568568
} else {
569-
s = session_mgr_->LocalPredict(
570-
call.request, call.response);
569+
s = session_mgr_->Warmup(
570+
call.request, call.response, false);
571571
}
572572
if (!s.ok()) return s;
573573

Diff for: serving/processor/serving/model_session.cc

+48-5
Original file line number | Diff line number | Diff line change
@@ -262,6 +262,16 @@ int ModelSession::GetServingSessionId() {
262262
}
263263

264264
// Two-argument entry point: obtains a session id from GetServingSessionId()
// and forwards the request/response pair to InternalPredict with that id.
// Returns whatever Status InternalPredict returns.
Status ModelSession::Predict(Request& req, Response& resp) {
265+
return InternalPredict(req, resp, GetServingSessionId());
266+
}
267+
268+
// Overload that takes the session id from the caller instead of calling
// GetServingSessionId(); `sess_id` is passed to InternalPredict unchanged.
// Returns whatever Status InternalPredict returns.
Status ModelSession::Predict(Request& req, Response& resp,
269+
int sess_id) {
270+
return InternalPredict(req, resp, sess_id);
271+
}
272+
273+
Status ModelSession::InternalPredict(Request& req, Response& resp,
274+
int sess_id) {
265275
if (is_local_) {
266276
return Status(error::Code::INTERNAL,
267277
"Local sparse storage, please use LocalPredict.");
@@ -278,17 +288,31 @@ Status ModelSession::Predict(Request& req, Response& resp) {
278288
// TODO: which session selected to run on, add some policy here
279289
status = session_group_->Run(run_options, req.inputs,
280290
req.output_tensor_names, {}, &resp.outputs,
281-
&run_metadata, GetServingSessionId());
291+
&run_metadata, sess_id);
282292
Tracer::GetTracer()->GenTimeline(run_metadata);
283293
} else {
284294
status = session_group_->Run(req.inputs, req.output_tensor_names,
285-
{}, &resp.outputs, GetServingSessionId());
295+
{}, &resp.outputs, sess_id);
286296
}
287297
--counter_;
288298
return status;
289299
}
290300

291-
Status ModelSession::LocalPredict(Request& req, Response& resp) {
301+
// Two-argument entry point: obtains a session id from GetServingSessionId()
// and forwards the request/response pair to InternalLocalPredict with that id.
// Returns whatever Status InternalLocalPredict returns.
Status ModelSession::LocalPredict(Request& req,
302+
Response& resp) {
303+
return InternalLocalPredict(req, resp,
304+
GetServingSessionId());
305+
}
306+
307+
// Overload that takes the session id from the caller instead of calling
// GetServingSessionId(); `sess_id` is passed to InternalLocalPredict unchanged.
// Returns whatever Status InternalLocalPredict returns.
Status ModelSession::LocalPredict(Request& req,
308+
Response& resp,
309+
int sess_id) {
310+
return InternalLocalPredict(req, resp, sess_id);
311+
}
312+
313+
Status ModelSession::InternalLocalPredict(Request& req,
314+
Response& resp,
315+
int sess_id) {
292316
if (!is_local_) {
293317
return Status(error::Code::INTERNAL,
294318
"Remote sparse storage, please use Predict.");
@@ -302,16 +326,31 @@ Status ModelSession::LocalPredict(Request& req, Response& resp) {
302326
// TODO: which session selected to run on, add some policy here
303327
status = session_group_->Run(run_options, req.inputs,
304328
req.output_tensor_names, {}, &resp.outputs,
305-
&run_metadata, GetServingSessionId());
329+
&run_metadata, sess_id);
306330
Tracer::GetTracer()->GenTimeline(run_metadata);
307331
} else {
308332
status = session_group_->Run(req.inputs, req.output_tensor_names,
309-
{}, &resp.outputs, GetServingSessionId());
333+
{}, &resp.outputs, sess_id);
310334
}
311335
--counter_;
312336
return status;
313337
}
314338

339+
// Sends the same request once for every session index i in
// [0, session_group_->GetSessionNum()).
//
// req   - request forwarded unchanged on every iteration.
// resp  - response object handed to every call; the same object is reused
//         across iterations.
// local - when true each call goes through LocalPredict(req, resp, i),
//         otherwise through Predict(req, resp, i).
//
// Returns the first non-OK Status encountered (remaining indices are not
// tried), or Status::OK() when every call succeeds or N is not positive.
//
// NOTE(review): session_group_ is dereferenced without a null check and is
// initialized to nullptr in the struct declaration - confirm it is always
// set before Warmup is called.
Status ModelSession::Warmup(Request& req, Response& resp, bool local) {
340+
int N = session_group_->GetSessionNum();
341+
for (int i = 0; i < N; ++i) {
342+
Status s;
343+
if (local) {
344+
// The loop index is used directly as the sess_id argument.
s = LocalPredict(req, resp, i);
345+
} else {
346+
s = Predict(req, resp, i);
347+
}
348+
// Stop at the first failure and report it to the caller.
if (!s.ok()) return s;
349+
}
350+
351+
return Status::OK();
352+
}
353+
315354
// Forwards the request/response pair to serving_session_->Predict and
// returns its Status unchanged.
Status ModelSessionMgr::Predict(Request& req, Response& resp) {
316355
return serving_session_->Predict(req, resp);
317356
}
@@ -320,6 +359,10 @@ Status ModelSessionMgr::LocalPredict(Request& req, Response& resp) {
320359
return serving_session_->LocalPredict(req, resp);
321360
}
322361

362+
// Forwards the warmup request to serving_session_->Warmup, passing `local`
// through unchanged, and returns its Status.
Status ModelSessionMgr::Warmup(Request& req, Response& resp, bool local) {
363+
return serving_session_->Warmup(req, resp, local);
364+
}
365+
323366
Status ModelSessionMgr::CreateModelSession(
324367
const Version& version, const char* ckpt_name,
325368
IFeatureStoreMgr* sparse_storage, bool is_incr_ckpt,

Diff for: serving/processor/serving/model_session.h

+6
Original file line number | Diff line number | Diff line change
@@ -33,10 +33,13 @@ struct ModelSession {
3333
virtual ~ModelSession();
3434

3535
Status Predict(Request& req, Response& resp);
36+
Status Predict(Request& req, Response& resp, int sess_id);
3637
Status LocalPredict(Request& req, Response& resp);
38+
Status LocalPredict(Request& req, Response& resp, int sess_id);
3739
Version GetVersion() {return version_;}
3840
void UpdateVersion(const Version& v) { version_ = v; }
3941
Session* GetSession();
42+
Status Warmup(Request& req, Response& resp, bool local=true);
4043

4144
SessionGroup* session_group_ = nullptr;
4245
SelectSessionPolicy select_session_policy_ =
@@ -54,6 +57,8 @@ struct ModelSession {
5457

5558
private:
5659
int GetServingSessionId();
60+
Status InternalPredict(Request& req, Response& resp, int sess_id);
61+
Status InternalLocalPredict(Request& req, Response& resp, int sess_id);
5762
};
5863

5964
class ModelSessionMgr {
@@ -64,6 +69,7 @@ class ModelSessionMgr {
6469

6570
Status Predict(Request& req, Response& resp);
6671
Status LocalPredict(Request& req, Response& resp);
72+
Status Warmup(Request& req, Response& resp, bool local=true);
6773

6874
Status CreateModelSession(
6975
const Version& version,

0 commit comments

Comments
 (0)