Skip to content

Commit 6a9d0fd

Browse files
cky9301 tensorflow-copybara
authored and committed
Add RunOptions to Servable interface.
PiperOrigin-RevId: 550921545
1 parent 3d8e8c3 commit 6a9d0fd

File tree

4 files changed

+48
-19
lines changed

4 files changed

+48
-19
lines changed

tensorflow_serving/servables/tensorflow/BUILD

+2
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,8 @@ cc_library(
10801080
"//visibility:public",
10811081
],
10821082
deps = [
1083+
":predict_response_tensor_serialization_option",
1084+
":thread_pool_factory",
10831085
"//tensorflow_serving/apis:classification_cc_proto",
10841086
"//tensorflow_serving/apis:get_model_metadata_cc_proto",
10851087
"//tensorflow_serving/apis:inference_cc_proto",

tensorflow_serving/servables/tensorflow/mock_servable.h

+11-6
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ limitations under the License.
1616
#ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_MOCK_SERVABLE_H_
1717
#define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_MOCK_SERVABLE_H_
1818

19-
#include <gmock/gmock.h>
2019
#include "absl/functional/any_invocable.h"
2120
#include "absl/status/status.h"
2221
#include "tensorflow_serving/servables/tensorflow/servable.h"
22+
#include "tensorflow_serving/test_util/test_util.h"
2323

2424
namespace tensorflow {
2525
namespace serving {
@@ -31,20 +31,25 @@ class MockServable : public Servable {
3131
~MockServable() override = default;
3232

3333
MOCK_METHOD(absl::Status, Classify,
34-
(const tensorflow::serving::ClassificationRequest& request,
34+
(const tensorflow::serving::Servable::RunOptions& run_options,
35+
const tensorflow::serving::ClassificationRequest& request,
3536
tensorflow::serving::ClassificationResponse* response));
3637
MOCK_METHOD(absl::Status, Regress,
37-
(const tensorflow::serving::RegressionRequest& request,
38+
(const tensorflow::serving::Servable::RunOptions& run_options,
39+
const tensorflow::serving::RegressionRequest& request,
3840
tensorflow::serving::RegressionResponse* response));
3941
MOCK_METHOD(absl::Status, Predict,
40-
(const tensorflow::serving::PredictRequest& request,
42+
(const tensorflow::serving::Servable::RunOptions& run_options,
43+
const tensorflow::serving::PredictRequest& request,
4144
tensorflow::serving::PredictResponse* response));
4245
MOCK_METHOD(absl::Status, PredictStreamed,
43-
(const tensorflow::serving::PredictRequest& request,
46+
(const tensorflow::serving::Servable::RunOptions& run_options,
47+
const tensorflow::serving::PredictRequest& request,
4448
absl::AnyInvocable<void(tensorflow::serving::PredictResponse)>
4549
response_callback));
4650
MOCK_METHOD(absl::Status, MultiInference,
47-
(const tensorflow::serving::MultiInferenceRequest& request,
51+
(const tensorflow::serving::Servable::RunOptions& run_options,
52+
const tensorflow::serving::MultiInferenceRequest& request,
4853
tensorflow::serving::MultiInferenceResponse* response));
4954
MOCK_METHOD(absl::Status, GetModelMetadata,
5055
(const tensorflow::serving::GetModelMetadataRequest& request,

tensorflow_serving/servables/tensorflow/servable.h

+31-10
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ limitations under the License.
2929
#include "tensorflow_serving/apis/inference.pb.h"
3030
#include "tensorflow_serving/apis/predict.pb.h"
3131
#include "tensorflow_serving/apis/regression.pb.h"
32+
#include "tensorflow_serving/servables/tensorflow/predict_response_tensor_serialization_option.h"
33+
#include "tensorflow_serving/servables/tensorflow/thread_pool_factory.h"
3234

3335
namespace tensorflow {
3436
namespace serving {
@@ -48,13 +50,27 @@ class Servable {
4850
// Returns the version associated with this servable.
4951
int64_t version() const { return version_; }
5052

51-
virtual absl::Status Classify(const ClassificationRequest& request,
53+
// RunOptions groups the configuration for individual inference executions.
54+
// The per-request configuration (e.g. deadline) can be passed here.
55+
struct RunOptions {
56+
// Priority of the request. Some thread pool implementations will schedule
57+
// ops based on the priority number. Larger number means higher
58+
// priority.
59+
int64_t priority = 1;
60+
// The deadline for this request.
61+
absl::Time deadline = absl::InfiniteFuture();
62+
};
63+
64+
virtual absl::Status Classify(const RunOptions& run_options,
65+
const ClassificationRequest& request,
5266
ClassificationResponse* response) = 0;
5367

54-
virtual absl::Status Regress(const RegressionRequest& request,
68+
virtual absl::Status Regress(const RunOptions& run_options,
69+
const RegressionRequest& request,
5570
RegressionResponse* response) = 0;
5671

57-
virtual absl::Status Predict(const PredictRequest& request,
72+
virtual absl::Status Predict(const RunOptions& run_options,
73+
const PredictRequest& request,
5874
PredictResponse* response) = 0;
5975

6076
// Streamed version of `Predict`. Experimental API that is not yet part of the
@@ -67,10 +83,11 @@ class Servable {
6783
// callback invocation to be delayed. The implementation guarantees that the
6884
// callback is never called after the `PredictStreamed` method returns.
6985
virtual absl::Status PredictStreamed(
70-
const PredictRequest& request,
86+
const RunOptions& run_options, const PredictRequest& request,
7187
absl::AnyInvocable<void(PredictResponse)> response_callback) = 0;
7288

73-
virtual absl::Status MultiInference(const MultiInferenceRequest& request,
89+
virtual absl::Status MultiInference(const RunOptions& run_options,
90+
const MultiInferenceRequest& request,
7491
MultiInferenceResponse* response) = 0;
7592

7693
virtual absl::Status GetModelMetadata(const GetModelMetadataRequest& request,
@@ -95,28 +112,32 @@ class EmptyServable : public Servable {
95112
public:
96113
EmptyServable();
97114

98-
absl::Status Classify(const ClassificationRequest& request,
115+
absl::Status Classify(const RunOptions& run_options,
116+
const ClassificationRequest& request,
99117
ClassificationResponse* response) override {
100118
return error_;
101119
}
102120

103-
absl::Status Regress(const RegressionRequest& request,
121+
absl::Status Regress(const RunOptions& run_options,
122+
const RegressionRequest& request,
104123
RegressionResponse* response) override {
105124
return error_;
106125
}
107126

108-
absl::Status Predict(const PredictRequest& request,
127+
absl::Status Predict(const RunOptions& run_options,
128+
const PredictRequest& request,
109129
PredictResponse* response) override {
110130
return error_;
111131
}
112132

113133
absl::Status PredictStreamed(
114-
const PredictRequest& request,
134+
const RunOptions& run_options, const PredictRequest& request,
115135
absl::AnyInvocable<void(PredictResponse)> response_callback) override {
116136
return error_;
117137
}
118138

119-
absl::Status MultiInference(const MultiInferenceRequest& request,
139+
absl::Status MultiInference(const RunOptions& run_options,
140+
const MultiInferenceRequest& request,
120141
MultiInferenceResponse* response) override {
121142
return error_;
122143
}

tensorflow_serving/servables/tensorflow/servable_test.cc

+4-3
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,19 @@ limitations under the License.
1515

1616
#include "tensorflow_serving/servables/tensorflow/servable.h"
1717

18-
#include <gmock/gmock.h>
19-
#include <gtest/gtest.h>
2018
#include "absl/status/status.h"
2119
#include "tensorflow_serving/apis/predict.pb.h"
20+
#include "tensorflow_serving/test_util/test_util.h"
2221

2322
namespace tensorflow {
2423
namespace serving {
2524
namespace {
2625

2726
TEST(EmptyServableTest, Predict) {
2827
PredictResponse response;
29-
EXPECT_EQ(EmptyServable().Predict(PredictRequest(), &response).code(),
28+
EXPECT_EQ(EmptyServable()
29+
.Predict(Servable::RunOptions(), PredictRequest(), &response)
30+
.code(),
3031
absl::StatusCode::kFailedPrecondition);
3132
}
3233

0 commit comments

Comments
 (0)