From 69063e71134668c9aa4a45ae0370dda45bf0324c Mon Sep 17 00:00:00 2001 From: jianpingzju Date: Wed, 2 Sep 2020 20:06:41 +0800 Subject: [PATCH] safer to unload model file --- tensorflow_serving/core/BUILD | 2 ++ .../core/aspired_version_policy.cc | 13 +++++++++++++ .../core/aspired_version_policy.h | 14 ++++++++++++++ .../core/aspired_versions_manager.h | 5 +++++ .../core/availability_preserving_policy.cc | 8 +++++++- .../model_servers/server_core.cc | 1 + .../file_system_storage_path_source.cc | 18 ++++++++++++++++++ .../file_system_storage_path_source.h | 4 ++++ 8 files changed, 64 insertions(+), 1 deletion(-) diff --git a/tensorflow_serving/core/BUILD b/tensorflow_serving/core/BUILD index e22dc4aec77..2088f85582b 100644 --- a/tensorflow_serving/core/BUILD +++ b/tensorflow_serving/core/BUILD @@ -659,6 +659,8 @@ cc_library( ":loader_harness", ":servable_id", "@com_google_absl//absl/types:optional", + "//tensorflow_serving/sources/storage_path:file_system_storage_path_source", + "//tensorflow_serving/sources/storage_path:file_system_storage_path_source_proto", "@org_tensorflow//tensorflow/core:lib", ], ) diff --git a/tensorflow_serving/core/aspired_version_policy.cc b/tensorflow_serving/core/aspired_version_policy.cc index 01cf119aa16..4ccc5146dfb 100644 --- a/tensorflow_serving/core/aspired_version_policy.cc +++ b/tensorflow_serving/core/aspired_version_policy.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow_serving/core/aspired_version_policy.h" namespace tensorflow { @@ -32,5 +34,16 @@ absl::optional AspiredVersionPolicy::GetHighestAspiredNewServableId( return highest_version_id; } +std::unordered_set +AspiredVersionPolicy::GetSpecificVersionsInConfig( + const std::string& servable_name) const { + mutex_lock l(mu_); + if (storage_path_source_ == nullptr) { + std::unordered_set empty_set; + return empty_set; + } + return storage_path_source_->GetSpecificVersionsInConfig(servable_name); +} + } // namespace serving } // namespace tensorflow diff --git a/tensorflow_serving/core/aspired_version_policy.h b/tensorflow_serving/core/aspired_version_policy.h index c0e40ee6bf8..1b38404197e 100644 --- a/tensorflow_serving/core/aspired_version_policy.h +++ b/tensorflow_serving/core/aspired_version_policy.h @@ -18,12 +18,15 @@ limitations under the License. #include #include +#include #include "absl/types/optional.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" #include "tensorflow_serving/core/loader_harness.h" #include "tensorflow_serving/core/servable_id.h" +#include "tensorflow_serving/sources/storage_path/file_system_storage_path_source.h" + namespace tensorflow { namespace serving { @@ -77,6 +80,15 @@ class AspiredVersionPolicy { virtual absl::optional GetNextAction( const std::vector& all_versions) const = 0; + void set_storage_path_source( + FileSystemStoragePathSource* storage_path_source) { + mutex_lock l(mu_); + storage_path_source_ = storage_path_source; + } + + std::unordered_set GetSpecificVersionsInConfig( + const std::string& servable_name) const; + protected: /// Returns the aspired ServableId with the highest version that matches /// kNew state, if any exists. @@ -85,6 +97,8 @@ class AspiredVersionPolicy { private: friend class AspiredVersionPolicyTest; + mutable mutex mu_; + FileSystemStoragePathSource* storage_path_source_ GUARDED_BY(mu_) = nullptr; }; inline bool operator==(const AspiredVersionPolicy::ServableAction& lhs, diff --git a/tensorflow_serving/core/aspired_versions_manager.h b/tensorflow_serving/core/aspired_versions_manager.h index d0395d72510..95f62ddf7f9 100644 --- a/tensorflow_serving/core/aspired_versions_manager.h +++ b/tensorflow_serving/core/aspired_versions_manager.h @@ -203,6 +203,11 @@ class AspiredVersionsManager : public Manager, // Source>::AspiredVersionsCallback GetAspiredVersionsCallback() override; + + void SetPolicyStoragePathSource( + FileSystemStoragePathSource* storage_path_source) { + aspired_version_policy_->set_storage_path_source(storage_path_source); + } private: friend class internal::AspiredVersionsManagerTargetImpl; diff --git a/tensorflow_serving/core/availability_preserving_policy.cc b/tensorflow_serving/core/availability_preserving_policy.cc index befc3b76139..a6e6bc8866c 100644 --- a/tensorflow_serving/core/availability_preserving_policy.cc +++ b/tensorflow_serving/core/availability_preserving_policy.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow_serving/core/availability_preserving_policy.h" #include "absl/types/optional.h" @@ -67,7 +69,11 @@ AvailabilityPreservingPolicy::GetNextAction( absl::optional version_to_unload = GetLowestServableId(unaspired_serving_versions); if (version_to_unload) { - return {{Action::kUnload, version_to_unload.value()}}; + std::unordered_set specific_versions = + GetSpecificVersionsInConfig(version_to_unload.value().name); + if (specific_versions.count(version_to_unload.value().version) == 0) { + return {{Action::kUnload, version_to_unload.value()}}; + } } } diff --git a/tensorflow_serving/model_servers/server_core.cc b/tensorflow_serving/model_servers/server_core.cc index 69ba76cb196..283df064b3d 100644 --- a/tensorflow_serving/model_servers/server_core.cc +++ b/tensorflow_serving/model_servers/server_core.cc @@ -347,6 +347,7 @@ Status ServerCore::AddModelsViaModelConfigList() { // Stow the source components. storage_path_source_and_router_ = {source.get(), router.get()}; + manager_->SetPolicyStoragePathSource(storage_path_source_and_router_->source); manager_.AddDependency(std::move(source)); if (prefix_source_adapter != nullptr) { manager_.AddDependency(std::move(prefix_source_adapter)); diff --git a/tensorflow_serving/sources/storage_path/file_system_storage_path_source.cc b/tensorflow_serving/sources/storage_path/file_system_storage_path_source.cc index db212542109..1ae5041af01 100644 --- a/tensorflow_serving/sources/storage_path/file_system_storage_path_source.cc +++ b/tensorflow_serving/sources/storage_path/file_system_storage_path_source.cc @@ -386,6 +386,24 @@ void FileSystemStoragePathSource::SetAspiredVersionsCallback( } } +std::unordered_set +FileSystemStoragePathSource::GetSpecificVersionsInConfig( + const std::string& servable_name) const { + mutex_lock l(mu_); + + std::unordered_set specific_versions; + for (const FileSystemStoragePathSourceConfig::ServableToMonitor& servable : + config_.servables()) { + if (servable.servable_name() == servable_name) { + specific_versions.insert( + servable.servable_version_policy().specific().versions().begin(), + servable.servable_version_policy().specific().versions().end()); + break; + } + } + return specific_versions; +} + Status FileSystemStoragePathSource::PollFileSystemAndInvokeCallback() { mutex_lock l(mu_); std::map>> diff --git a/tensorflow_serving/sources/storage_path/file_system_storage_path_source.h b/tensorflow_serving/sources/storage_path/file_system_storage_path_source.h index ef51009d634..b2d7d56531c 100644 --- a/tensorflow_serving/sources/storage_path/file_system_storage_path_source.h +++ b/tensorflow_serving/sources/storage_path/file_system_storage_path_source.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include "absl/types/variant.h" #include "tensorflow/core/kernels/batching_util/periodic_function.h" @@ -77,6 +78,9 @@ class FileSystemStoragePathSource : public Source { mutex_lock l(mu_); return config_; } + + std::unordered_set GetSpecificVersionsInConfig( + const std::string& servable_name) const; private: friend class internal::FileSystemStoragePathSourceTestAccess;