Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit d40ac6c

Browse files
author
jamieposton
committed
Added functionality for querying by version number/version label.
1 parent 2ea8ec1 commit d40ac6c

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

tensor2tensor/serving/query.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
"cloud_mlengine_model_version", None,
4949
"Version of the model to use. If None, requests will be "
5050
"sent to the default version.")
51+
flags.DEFINE_string("version", None, "Version of the model to use.")
52+
flags.DEFINE_string("version_label", None, "Label of the model to use.")
5153

5254

5355
def validate_flags():
@@ -72,7 +74,9 @@ def make_request_fn():
7274
request_fn = serving_utils.make_grpc_request_fn(
7375
servable_name=FLAGS.servable_name,
7476
server=FLAGS.server,
75-
timeout_secs=FLAGS.timeout_secs)
77+
timeout_secs=FLAGS.timeout_secs,
78+
version_label=FLAGS.version_label,
79+
version=FLAGS.version)
7680
return request_fn
7781

7882

tensor2tensor/serving/serving_utils.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,20 @@ def _decode(output_ids, output_decoder):
105105

106106

107107

108-
def make_grpc_request_fn(servable_name, server, timeout_secs):
108+
def make_grpc_request_fn(servable_name, server, timeout_secs, version_label, version):
109109
"""Wraps function to make grpc requests with runtime args."""
110110
stub = _create_stub(server)
111111

112112
def _make_grpc_request(examples):
113113
"""Builds and sends request to TensorFlow model server."""
114114
request = predict_pb2.PredictRequest()
115115
request.model_spec.name = servable_name
116+
117+
if version_label is not None:
118+
request.model_spec.version_label = version_label
119+
elif version is not None:
120+
request.model_spec.version = version
121+
116122
request.inputs["input"].CopyFrom(
117123
tf.make_tensor_proto(
118124
[ex.SerializeToString() for ex in examples], shape=[len(examples)]))

0 commit comments

Comments
 (0)