diff --git a/tensorflow_serving/model_servers/main.cc b/tensorflow_serving/model_servers/main.cc index 2b83500692d..e83cb919b14 100644 --- a/tensorflow_serving/model_servers/main.cc +++ b/tensorflow_serving/model_servers/main.cc @@ -192,7 +192,11 @@ int main(int argc, char** argv) { "EXPERIMENTAL; CAN BE REMOVED ANYTIME! Load and use " "TensorFlow Lite model from `model.tflite` file in " "SavedModel directory instead of the TensorFlow model " - "from `saved_model.pb` file.")}; + "from `saved_model.pb` file."), + tensorflow::Flag("total_memory_limit_megabytes", + &options.total_model_memory_limit_megabytes, + "Total model memory limit in megabytes"), + }; const auto& usage = tensorflow::Flags::Usage(argv[0], flag_list); if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) { diff --git a/tensorflow_serving/model_servers/server.cc b/tensorflow_serving/model_servers/server.cc index a47dee9a893..48a849d4536 100644 --- a/tensorflow_serving/model_servers/server.cc +++ b/tensorflow_serving/model_servers/server.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include #include #include +#include #include "google/protobuf/wrappers.pb.h" #include "grpc/grpc.h" @@ -285,7 +287,9 @@ Status Server::BuildAndStart(const Options& server_options) { options.flush_filesystem_caches = server_options.flush_filesystem_caches; options.allow_version_labels_for_unavailable_models = server_options.allow_version_labels_for_unavailable_models; - + options.total_model_memory_limit_bytes = std::min( + ((uint64)server_options.total_model_memory_limit_megabytes) << 20, + std::numeric_limits::max()); TF_RETURN_IF_ERROR(ServerCore::Create(std::move(options), &server_core_)); // Model config polling thread must be started after the call to diff --git a/tensorflow_serving/model_servers/server.h b/tensorflow_serving/model_servers/server.h index 6a71265f0c5..e5e4e17ae62 100644 --- a/tensorflow_serving/model_servers/server.h +++ b/tensorflow_serving/model_servers/server.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_SERVING_MODEL_SERVERS_SERVER_H_ #include +#include #include "grpcpp/server.h" #include "tensorflow/core/kernels/batching_util/periodic_function.h" @@ -59,6 +60,8 @@ class Server { tensorflow::string batching_parameters_file; tensorflow::string model_name; tensorflow::int32 max_num_load_retries = 5; + tensorflow::int64 total_model_memory_limit_megabytes = + std::numeric_limits::max() >> 20; tensorflow::int64 load_retry_interval_micros = 1LL * 60 * 1000 * 1000; tensorflow::int32 file_system_poll_wait_seconds = 1; bool flush_filesystem_caches = true;