From e80955361d00d081f251dc6427cac77e76b9105f Mon Sep 17 00:00:00 2001 From: Peng YU Date: Thu, 13 Feb 2020 13:31:24 +0100 Subject: [PATCH] Add list live model names endpoint --- tensorflow_serving/apis/get_model_status.proto | 10 ++++++++++ tensorflow_serving/apis/model_service.proto | 3 +++ .../model_servers/get_model_status_impl.cc | 17 +++++++++++++++++ .../model_servers/get_model_status_impl.h | 6 +++++- .../model_servers/get_model_status_impl_test.cc | 11 +++++++++++ 5 files changed, 46 insertions(+), 1 deletion(-) diff --git a/tensorflow_serving/apis/get_model_status.proto b/tensorflow_serving/apis/get_model_status.proto index 535eb9acb68..672cff41230 100644 --- a/tensorflow_serving/apis/get_model_status.proto +++ b/tensorflow_serving/apis/get_model_status.proto @@ -16,6 +16,16 @@ message GetModelStatusRequest { ModelSpec model_spec = 1; } +//request to get all live model names +message ListModelNamesRequest{ +} + +//response contains all live model names +message ListModelNamesResponse{ + repeated string model_names = 1; +} + + // Version number, state, and status for a single version of a model. message ModelVersionStatus { // Model version. diff --git a/tensorflow_serving/apis/model_service.proto b/tensorflow_serving/apis/model_service.proto index 29a3b077512..5230c01f6b8 100644 --- a/tensorflow_serving/apis/model_service.proto +++ b/tensorflow_serving/apis/model_service.proto @@ -21,4 +21,7 @@ service ModelService { // longer served. rpc HandleReloadConfigRequest(ReloadConfigRequest) returns (ReloadConfigResponse); + + // Lists the live model names + rpc ListModelNames(ListModelNamesRequest) returns (ListModelNamesResponse); } diff --git a/tensorflow_serving/model_servers/get_model_status_impl.cc b/tensorflow_serving/model_servers/get_model_status_impl.cc index 813e7119627..e78547c9626 100644 --- a/tensorflow_serving/model_servers/get_model_status_impl.cc +++ b/tensorflow_serving/model_servers/get_model_status_impl.cc @@ -62,6 +62,12 @@ void AddModelVersionStatusToResponse(GetModelStatusResponse* response, } // namespace +// Add model name to ListModelNamesResponse +void AddModelNameToResponse(ListModelNamesResponse* response, + const string& name){ + response->add_model_names(name); +} + Status GetModelStatusImpl::GetModelStatus(ServerCore* core, const GetModelStatusRequest& request, GetModelStatusResponse* response) { @@ -106,5 +112,16 @@ Status GetModelStatusImpl::GetModelStatusWithModelSpec( return tensorflow::Status::OK(); } +Status GetModelStatusImpl::ListModelNames(ServerCore* core, + const ListModelNamesRequest& request, + ListModelNamesResponse* response) { + + const ServableStateMonitor& monitor = *core->servable_state_monitor(); + for(const auto& servable: monitor.GetAllServableStates()) { + AddModelNameToResponse(response, servable.first); + } + return tensorflow::Status::OK(); +} + } // namespace serving } // namespace tensorflow diff --git a/tensorflow_serving/model_servers/get_model_status_impl.h b/tensorflow_serving/model_servers/get_model_status_impl.h index a814dc685cd..6c4323aa361 100644 --- a/tensorflow_serving/model_servers/get_model_status_impl.h +++ b/tensorflow_serving/model_servers/get_model_status_impl.h @@ -38,8 +38,12 @@ class GetModelStatusImpl { static Status GetModelStatusWithModelSpec( ServerCore* core, const ModelSpec& model_spec, const GetModelStatusRequest& request, GetModelStatusResponse* response); -}; + // List all live model names + static Status ListModelNames(ServerCore* core, + const ListModelNamesRequest& request, + ListModelNamesResponse* response); +}; } // namespace serving } // namespace tensorflow diff --git a/tensorflow_serving/model_servers/get_model_status_impl_test.cc b/tensorflow_serving/model_servers/get_model_status_impl_test.cc index 9002a8526c4..949be58c258 100644 --- a/tensorflow_serving/model_servers/get_model_status_impl_test.cc +++ b/tensorflow_serving/model_servers/get_model_status_impl_test.cc @@ -172,6 +172,17 @@ TEST_F(GetModelStatusImplTest, SingleVersionSuccess) { EXPECT_EQ("", response.model_version_status(0).status().error_message()); } +TEST_F(GetModelStatusImplTest, Success) { + ListModelsRequest request; + ListModelsResponse response; + // If two versions of model are managed by ServerCore, succesfully get model + // status for both versions of the model. + TF_EXPECT_OK( + GetModelStatusImpl::ListModels(GetServerCore(), request, &response) + ); + EXPECT_EQ(1, response.model_names_size()); +} + // Verifies that GetModelStatusWithModelSpec() uses the model spec override // rather than the one in the request. TEST_F(GetModelStatusImplTest, ModelSpecOverride) {