diff --git a/tensorflow_serving/model_servers/main.cc b/tensorflow_serving/model_servers/main.cc
index 2b83500692d..e83cb919b14 100644
--- a/tensorflow_serving/model_servers/main.cc
+++ b/tensorflow_serving/model_servers/main.cc
@@ -192,7 +192,11 @@ int main(int argc, char** argv) {
"EXPERIMENTAL; CAN BE REMOVED ANYTIME! Load and use "
"TensorFlow Lite model from `model.tflite` file in "
"SavedModel directory instead of the TensorFlow model "
- "from `saved_model.pb` file.")};
+ "from `saved_model.pb` file."),
+ tensorflow::Flag("total_memory_limit_megabytes",
+ &options.total_model_memory_limit_megabytes,
+ "Total model memory limit in megabytes"),
+ };
const auto& usage = tensorflow::Flags::Usage(argv[0], flag_list);
if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
diff --git a/tensorflow_serving/model_servers/server.cc b/tensorflow_serving/model_servers/server.cc
index a47dee9a893..48a849d4536 100644
--- a/tensorflow_serving/model_servers/server.cc
+++ b/tensorflow_serving/model_servers/server.cc
@@ -19,8 +19,10 @@ limitations under the License.
 #include <unistd.h>
 #include <iostream>
+#include <algorithm>
 #include <memory>
 #include <utility>
+#include <limits>
#include "google/protobuf/wrappers.pb.h"
#include "grpc/grpc.h"
@@ -285,7 +287,9 @@ Status Server::BuildAndStart(const Options& server_options) {
options.flush_filesystem_caches = server_options.flush_filesystem_caches;
options.allow_version_labels_for_unavailable_models =
server_options.allow_version_labels_for_unavailable_models;
-
+ options.total_model_memory_limit_bytes = std::min(
+ ((uint64)server_options.total_model_memory_limit_megabytes) << 20,
+      std::numeric_limits<uint64>::max());
TF_RETURN_IF_ERROR(ServerCore::Create(std::move(options), &server_core_));
// Model config polling thread must be started after the call to
diff --git a/tensorflow_serving/model_servers/server.h b/tensorflow_serving/model_servers/server.h
index 6a71265f0c5..e5e4e17ae62 100644
--- a/tensorflow_serving/model_servers/server.h
+++ b/tensorflow_serving/model_servers/server.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_SERVING_MODEL_SERVERS_SERVER_H_
 #include <memory>
+#include <limits>
#include "grpcpp/server.h"
#include "tensorflow/core/kernels/batching_util/periodic_function.h"
@@ -59,6 +60,8 @@ class Server {
tensorflow::string batching_parameters_file;
tensorflow::string model_name;
tensorflow::int32 max_num_load_retries = 5;
+ tensorflow::int64 total_model_memory_limit_megabytes =
+        std::numeric_limits<tensorflow::int64>::max() >> 20;
tensorflow::int64 load_retry_interval_micros = 1LL * 60 * 1000 * 1000;
tensorflow::int32 file_system_poll_wait_seconds = 1;
bool flush_filesystem_caches = true;