diff --git a/tensorflow_serving/servables/tensorflow/predict_util.cc b/tensorflow_serving/servables/tensorflow/predict_util.cc index ca3d62611c6..54ef1d00543 100644 --- a/tensorflow_serving/servables/tensorflow/predict_util.cc +++ b/tensorflow_serving/servables/tensorflow/predict_util.cc @@ -64,7 +64,25 @@ std::set SetDifference(std::set set_a, std::set set_b) { Status VerifyRequestInputsSize(const SignatureDef& signature, const PredictRequest& request) { - if (request.inputs().size() != signature.inputs().size()) { + if (request.inputs().size() >= signature.inputs().size()) { + const std::set request_inputs = GetMapKeys(request.inputs()); + const std::set signature_inputs = GetMapKeys(signature.inputs()); + const std::set missing = + SetDifference(signature_inputs, request_inputs); + if (!missing.empty()) { + const std::set sent_extra = + SetDifference(request_inputs, signature_inputs); + return tensorflow::Status( + tensorflow::error::INVALID_ARGUMENT, + absl::StrCat( + "input size does not match signature: ", request.inputs().size(), + "!=", signature.inputs().size(), " len({", + absl::StrJoin(request_inputs, ","), "}) != len({", + absl::StrJoin(signature_inputs, ","), "}). Sent extra: {", + absl::StrJoin(sent_extra, ","), "}. Missing but required: {", + absl::StrJoin(missing, ","), "}.")); + } + } else { const std::set request_inputs = GetMapKeys(request.inputs()); const std::set signature_inputs = GetMapKeys(signature.inputs()); const std::set sent_extra = @@ -93,10 +111,10 @@ Status PreProcessPrediction(const SignatureDef& signature, std::vector* output_tensor_aliases) { TF_RETURN_IF_ERROR(VerifySignature(signature)); TF_RETURN_IF_ERROR(VerifyRequestInputsSize(signature, request)); - for (auto& input : request.inputs()) { + for (auto& input : signature.inputs()) { const string& alias = input.first; - auto iter = signature.inputs().find(alias); - if (iter == signature.inputs().end()) { + auto iter = request.inputs().find(alias); + if (iter == request.inputs().end()) { return tensorflow::Status( tensorflow::error::INVALID_ARGUMENT, strings::StrCat("input tensor alias not found in signature: ", alias, @@ -105,11 +123,11 @@ Status PreProcessPrediction(const SignatureDef& signature, "}.")); } Tensor tensor; - if (!tensor.FromProto(input.second)) { + if (!tensor.FromProto(iter->second)) { return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT, "tensor parsing error: " + alias); } - inputs->emplace_back(std::make_pair(iter->second.name(), tensor)); + inputs->emplace_back(std::make_pair(input.second.name(), tensor)); } // Prepare run target. diff --git a/tensorflow_serving/util/json_tensor.cc b/tensorflow_serving/util/json_tensor.cc index 8c3c8d1ed7f..5c2556277de 100644 --- a/tensorflow_serving/util/json_tensor.cc +++ b/tensorflow_serving/util/json_tensor.cc @@ -434,10 +434,6 @@ Status AddInstanceItem(const rapidjson::Value& item, const string& name, ::google::protobuf::Map* size_map, ::google::protobuf::Map* shape_map, ::google::protobuf::Map* tensor_map) { - if (!tensorinfo_map.count(name)) { - return errors::InvalidArgument("JSON object: does not have named input: ", - name); - } int size = 0; const auto dtype = tensorinfo_map.at(name).dtype(); auto* tensor = &(*tensor_map)[name]; @@ -545,6 +541,9 @@ Status FillTensorMapFromInstancesList( std::set object_keys; for (const auto& kv : elem.GetObject()) { const string& name = kv.name.GetString(); + if (!tensorinfo_map.count(name)) { + continue; + } object_keys.insert(name); const auto status = AddInstanceItem(kv.value, name, tensorinfo_map, &size_map, &shape_map, tensor_map);