Skip to content

Commit a7c37f9

Browse files
shanshanptliutongxuan
authored andcommitted
[Serving] Fix warmup failed bug when user set warmup file path. (#507)
1 parent 70bd15e commit a7c37f9

File tree

6 files changed

+124
-51
lines changed

6 files changed

+124
-51
lines changed

Diff for: serving/processor/serving/BUILD

+21-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ tf_cc_shared_object(
2727
srcs = ["processor.cc",
2828
"processor.h",],
2929
deps = [
30-
"model_serving"],
30+
"model_serving",
31+
],
3132
)
3233

3334
cc_library(
@@ -130,6 +131,23 @@ cc_test(
130131
"@com_google_googletest//:gtest_main",],
131132
)
132133

134+
cc_library(
135+
name = "message_coding",
136+
srcs = ["message_coding.cc",],
137+
hdrs = ["message_coding.h",],
138+
deps = [
139+
"//tensorflow/core:protos_all_cc",
140+
"//tensorflow/core:framework",
141+
"//tensorflow/core:core_cpu",
142+
"//tensorflow/core:lib",
143+
"//serving/processor/framework:model_version",
144+
"//serving/processor/storage:model_store",
145+
"model_message",
146+
"predict_proto_cc",
147+
"utils",
148+
],
149+
)
150+
133151
cc_library(
134152
name = "model_instance",
135153
srcs = ["model_instance.cc",],
@@ -144,6 +162,7 @@ cc_library(
144162
"//serving/processor/framework:graph_optimizer",
145163
"//serving/processor/framework:model_version",
146164
"//serving/processor/storage:model_store",
165+
":message_coding",
147166
"model_config",
148167
"model_partition",
149168
"model_session",
@@ -156,10 +175,8 @@ cc_library(
156175
cc_library(
157176
name = "model_serving",
158177
srcs = ["model_serving.cc",
159-
"message_coding.cc",
160178
"model_impl.cc",],
161179
hdrs = ["model_serving.h",
162-
"message_coding.h",
163180
"model_impl.h",],
164181
deps = [
165182
"//tensorflow/core:protos_all_cc",
@@ -168,6 +185,7 @@ cc_library(
168185
"//tensorflow/core:lib",
169186
"//serving/processor/framework:model_version",
170187
"//serving/processor/storage:model_store",
188+
":message_coding",
171189
"model_config",
172190
"model_session",
173191
"model_message",

Diff for: serving/processor/serving/message_coding.cc

+17-6
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,9 @@ ProtoBufParser::ProtoBufParser(int thread_num) {
99
thread_num));
1010
}
1111

12-
Status ProtoBufParser::ParseRequestFromBuf(
13-
const void* input_data, int input_size, Call& call,
14-
const SignatureInfo* signature_info) {
15-
eas::PredictRequest request;
16-
request.ParseFromArray(input_data, input_size);
17-
12+
Status ProtoBufParser::ParseRequest(
13+
const eas::PredictRequest& request,
14+
const SignatureInfo* signature_info, Call& call) {
1815
for (auto& input : request.inputs()) {
1916
if (signature_info->input_key_idx.find(input.first) ==
2017
signature_info->input_key_idx.end()) {
@@ -44,6 +41,20 @@ Status ProtoBufParser::ParseRequestFromBuf(
4441
return Status::OK();
4542
}
4643

44+
Status ProtoBufParser::ParseRequestFromBuf(
45+
const void* input_data, int input_size, Call& call,
46+
const SignatureInfo* signature_info) {
47+
eas::PredictRequest request;
48+
bool success = request.ParseFromArray(input_data, input_size);
49+
if (!success) {
50+
LOG(ERROR) << "Parse request from array failed, input_data: " << input_data
51+
<< ", input_size: " << input_size;
52+
return Status(errors::Code::INVALID_ARGUMENT, "Please check the input data.");
53+
}
54+
55+
return ParseRequest(request, signature_info, call);
56+
}
57+
4758
Status ProtoBufParser::ParseResponseToBuf(
4859
const Call& call, void** output_data, int* output_size,
4960
const SignatureInfo* signature_info) {

Diff for: serving/processor/serving/message_coding.h

+15
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "tensorflow/core/lib/core/status.h"
55
#include "tensorflow/core/lib/core/threadpool.h"
6+
#include "serving/processor/serving/predict.pb.h"
67

78
namespace tensorflow {
89
namespace processor {
@@ -19,6 +20,10 @@ class IParser {
1920
const void* input_data, int input_size, Call& call,
2021
const SignatureInfo* info) = 0;
2122

23+
virtual Status ParseRequest(
24+
const eas::PredictRequest& request,
25+
const SignatureInfo* signature_info, Call& call) = 0;
26+
2227
virtual Status ParseResponseToBuf(
2328
const Call& call, void** output_data,
2429
int* output_size, const SignatureInfo* info) = 0;
@@ -52,6 +57,10 @@ class ProtoBufParser : public IParser {
5257
const void* input_data, int input_size,
5358
Call& call, const SignatureInfo* info) override;
5459

60+
Status ParseRequest(
61+
const eas::PredictRequest& request,
62+
const SignatureInfo* signature_info, Call& call) override;
63+
5564
Status ParseResponseToBuf(
5665
const Call& call, void** output_data,
5766
int* output_size, const SignatureInfo* info) override;
@@ -83,6 +92,12 @@ class FlatBufferParser : public IParser {
8392
return Status::OK();
8493
}
8594

95+
Status ParseRequest(
96+
const eas::PredictRequest& request,
97+
const SignatureInfo* signature_info, Call& call) override {
98+
return Status::OK();
99+
}
100+
86101
Status ParseResponseToBuf(
87102
const Call& call, void** output_data,
88103
int* output_size, const SignatureInfo* info) override {

Diff for: serving/processor/serving/model_instance.cc

+67-40
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <fstream>
2+
#include "serving/processor/serving/message_coding.h"
23
#include "serving/processor/serving/model_instance.h"
34
#include "serving/processor/serving/model_partition.h"
45
#include "serving/processor/serving/model_session.h"
@@ -24,6 +25,7 @@ namespace processor {
2425
namespace {
2526
constexpr int _60_Seconds = 60;
2627
constexpr int MAX_TRY_COUNT = 10;
28+
constexpr int WARMUP_COUNT = 5;
2729

2830
Tensor CreateTensor(const TensorInfo& tensor_info) {
2931
auto real_ts = tensor_info.tensor_shape();
@@ -71,44 +73,35 @@ Tensor CreateTensor(const TensorInfo& tensor_info) {
7173
return tensor;
7274
}
7375

74-
Call CreateWarmupParams(SignatureDef& sig_def) {
75-
Call call;
76+
Status CreateWarmupParams(SignatureDef& sig_def, Call* call) {
7677
for (auto it : sig_def.inputs()) {
7778
const auto& tensor = CreateTensor(it.second);
78-
call.request.inputs.emplace_back(it.second.name(), tensor);
79+
call->request.inputs.emplace_back(it.second.name(), tensor);
7980
}
8081

8182
for (auto it : sig_def.outputs()) {
82-
call.request.output_tensor_names.emplace_back(it.second.name());
83+
call->request.output_tensor_names.emplace_back(it.second.name());
8384
}
8485

85-
return call;
86+
return Status::OK();
8687
}
8788

88-
Call CreateWarmupParams(SignatureDef& sig_def,
89-
const std::string& warmup_file_name) {
89+
Status CreateWarmupParams(SignatureDef& sig_def,
90+
const std::string& warmup_file_name,
91+
Call* call, IParser* parser,
92+
const SignatureInfo& signature_info) {
9093
// Parse warmup file
9194
eas::PredictRequest request;
9295
std::fstream input(warmup_file_name, std::ios::in | std::ios::binary);
93-
request.ParseFromIstream(&input);
94-
input.close();
95-
96-
Call call;
97-
for (auto& input : request.inputs()) {
98-
call.request.inputs.emplace_back(input.first,
99-
util::Proto2Tensor(input.second));
100-
}
101-
102-
call.request.output_tensor_names =
103-
std::vector<std::string>(request.output_filter().begin(),
104-
request.output_filter().end());
105-
106-
// User need to set fetches
107-
if (call.request.output_tensor_names.size() == 0) {
108-
LOG(FATAL) << "warmup file must be contain fetches.";
96+
bool success = request.ParseFromIstream(&input);
97+
if (!success) {
98+
LOG(ERROR) << "Read warmp file failed: " << warmup_file_name;
99+
return Status(error::Code::INTERNAL,
100+
"Read warmp file failed, please check warmp file path");
109101
}
102+
input.close();
110103

111-
return call;
104+
return parser->ParseRequest(request, &signature_info, *call);
112105
}
113106

114107
bool ShouldWarmup(SignatureDef& sig_def) {
@@ -264,6 +257,7 @@ Status LocalSessionInstance::Init(ModelConfig* config,
264257
{kSavedModelTagServe}, &meta_graph_def_));
265258

266259
warmup_file_name_ = config->warmup_file_name;
260+
parser_ = ParserFactory::GetInstance(config->serialize_protocol, 4);
267261

268262
GraphOptimizerOption option;
269263
option.native_tf_mode = true;
@@ -352,21 +346,38 @@ Status LocalSessionInstance::Warmup(
352346
return Status::OK();
353347
}
354348

349+
LOG(INFO) << "Try to warmup model: " << warmup_file_name_;
350+
Status s;
355351
Call call;
356352
if (warmup_file_name_.empty()) {
357-
call = CreateWarmupParams(model_signature_.second);
353+
s = CreateWarmupParams(model_signature_.second, &call);
358354
} else {
359-
call = CreateWarmupParams(model_signature_.second,
360-
warmup_file_name_);
355+
s = CreateWarmupParams(model_signature_.second,
356+
warmup_file_name_, &call,
357+
parser_, signature_info_);
358+
}
359+
if (!s.ok()) {
360+
LOG(ERROR) << "Create warmup params failed, warmup will be canceled.";
361+
return s;
361362
}
362363

363-
if (warmup_session) {
364-
return warmup_session->LocalPredict(
365-
call.request, call.response);
364+
int left_try_count = WARMUP_COUNT;
365+
while (left_try_count > 0) {
366+
if (warmup_session) {
367+
s = warmup_session->LocalPredict(
368+
call.request, call.response);
369+
} else {
370+
s = session_mgr_->LocalPredict(
371+
call.request, call.response);
372+
}
373+
if (!s.ok()) return s;
374+
375+
--left_try_count;
376+
call.response.outputs.clear();
366377
}
378+
LOG(INFO) << "Warmup model successful: " << warmup_file_name_;
367379

368-
return session_mgr_->LocalPredict(
369-
call.request, call.response);
380+
return Status::OK();
370381
}
371382

372383
std::string LocalSessionInstance::DebugString() {
@@ -474,6 +485,7 @@ Status RemoteSessionInstance::Init(ModelConfig* model_config,
474485
backup_storage_ = new FeatureStoreMgr(&backup_model_config);
475486

476487
warmup_file_name_ = model_config->warmup_file_name;
488+
parser_ = ParserFactory::GetInstance(model_config->serialize_protocol, 4);
477489

478490
// set active flag
479491
serving_storage_->SetStorageActiveStatus(active);
@@ -534,21 +546,36 @@ Status RemoteSessionInstance::Warmup(
534546
return Status::OK();
535547
}
536548

549+
Status s;
537550
Call call;
538551
if (warmup_file_name_.empty()) {
539-
call = CreateWarmupParams(model_signature_.second);
552+
s = CreateWarmupParams(model_signature_.second, &call);
540553
} else {
541-
call = CreateWarmupParams(model_signature_.second,
542-
warmup_file_name_);
554+
s = CreateWarmupParams(model_signature_.second,
555+
warmup_file_name_, &call,
556+
parser_, signature_info_);
557+
}
558+
if (!s.ok()) {
559+
LOG(ERROR) << "Create warmup params failed, warmup will be canceled.";
560+
return s;
543561
}
544562

545-
if (warmup_session) {
546-
return warmup_session->Predict(
547-
call.request, call.response);
563+
int left_try_count = WARMUP_COUNT;
564+
while (left_try_count > 0) {
565+
if (warmup_session) {
566+
s = warmup_session->LocalPredict(
567+
call.request, call.response);
568+
} else {
569+
s = session_mgr_->LocalPredict(
570+
call.request, call.response);
571+
}
572+
if (!s.ok()) return s;
573+
574+
--left_try_count;
575+
call.response.outputs.clear();
548576
}
549577

550-
return session_mgr_->Predict(
551-
call.request, call.response);
578+
return Status::OK();
552579
}
553580

554581
Status RemoteSessionInstance::FullModelUpdate(

Diff for: serving/processor/serving/model_instance.h

+3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class ModelStore;
2222
class ModelSession;
2323
class ModelSessionMgr;
2424
class IFeatureStoreMgr;
25+
class IParser;
2526

2627
class LocalSessionInstance {
2728
public:
@@ -59,6 +60,7 @@ class LocalSessionInstance {
5960
SignatureInfo signature_info_;
6061

6162
std::string warmup_file_name_;
63+
IParser* parser_ = nullptr;
6264

6365
ModelSessionMgr* session_mgr_ = nullptr;
6466
SessionOptions* session_options_ = nullptr;
@@ -108,6 +110,7 @@ class RemoteSessionInstance {
108110
SignatureInfo signature_info_;
109111

110112
std::string warmup_file_name_;
113+
IParser* parser_ = nullptr;
111114

112115
ModelSessionMgr* session_mgr_ = nullptr;
113116
SessionOptions* session_options_ = nullptr;

Diff for: serving/processor/serving/model_serving.cc

+1-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ Status Model::Init(const char* model_config) {
2424
}
2525

2626
if (!config->warmup_file_name.empty()) {
27-
config->warmup_file_name =
28-
model_entry_ + config->warmup_file_name;
27+
LOG(INFO) << "User set warmup file: " << config->warmup_file_name;
2928
}
3029

3130
parser_ = ParserFactory::GetInstance(config->serialize_protocol,

0 commit comments

Comments
 (0)