Skip to content

Commit 3eee01a

Browse files
committed
Fix stack exhaustion DoS in json parsing via iterative shape traversal
1 parent 14a8723 commit 3eee01a

2 files changed

Lines changed: 44 additions & 9 deletions

File tree

tensorflow_serving/util/json_tensor.cc

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -346,13 +346,25 @@ Status AddValueToTensor(const rapidjson::Value& val, DataType dtype,
346346
// `val` can be scalar or list or list of lists with arbitrary nesting. If a
347347
// scalar (non array) is passed, we do not add dimension info to shape (as
348348
// scalars do not have a dimension).
349-
void GetDenseTensorShape(const rapidjson::Value& val, TensorShapeProto* shape) {
350-
if (!val.IsArray()) return;
351-
const auto size = val.Size();
352-
shape->add_dim()->set_size(size);
353-
if (size > 0) {
354-
GetDenseTensorShape(val[0], shape);
349+
constexpr int kMaxTensorRank = 254;
350+
351+
Status GetDenseTensorShape(const rapidjson::Value& val, TensorShapeProto* shape) {
352+
const rapidjson::Value* curr = &val;
353+
int depth = 0;
354+
while (curr->IsArray()) {
355+
if (++depth > kMaxTensorRank) {
356+
return errors::InvalidArgument(
357+
"Tensor rank exceeds maximum allowed rank: ", kMaxTensorRank);
358+
}
359+
const auto size = curr->Size();
360+
shape->add_dim()->set_size(size);
361+
if (size > 0) {
362+
curr = &((*curr)[0]);
363+
} else {
364+
break;
365+
}
355366
}
367+
return OkStatus();
356368
}
357369

358370
bool IsValBase64Object(const rapidjson::Value& val) {
@@ -392,6 +404,10 @@ Status JsonDecodeBase64Object(const rapidjson::Value& val,
392404
Status FillTensorProto(const rapidjson::Value& val, int level, DataType dtype,
393405
int* val_count, TensorProto* tensor) {
394406
const auto rank = tensor->tensor_shape().dim_size();
407+
if (rank > kMaxTensorRank) {
408+
return errors::InvalidArgument(
409+
"Tensor rank ", rank, " exceeds maximum allowed rank ", kMaxTensorRank);
410+
}
395411
if (!val.IsArray()) {
396412
// DOM tree for a (dense) tensor will always have all values
397413
// at same (leaf) level equal to the rank of the tensor.
@@ -453,7 +469,7 @@ Status AddInstanceItem(const rapidjson::Value& item, const string& name,
453469
const auto dtype = tensorinfo_map.at(name).dtype();
454470
auto* tensor = &(*tensor_map)[name];
455471
tensor->mutable_tensor_shape()->Clear();
456-
GetDenseTensorShape(item, tensor->mutable_tensor_shape());
472+
TF_RETURN_IF_ERROR(GetDenseTensorShape(item, tensor->mutable_tensor_shape()));
457473
TF_RETURN_IF_ERROR(
458474
FillTensorProto(item, 0 /* level */, dtype, &size, tensor));
459475
if (!size_map->count(name)) {
@@ -623,7 +639,7 @@ Status FillTensorMapFromInputsMap(
623639

624640
auto* tensor = &(*tensor_map)[tensorinfo_map.begin()->first];
625641
tensor->set_dtype(tensorinfo_map.begin()->second.dtype());
626-
GetDenseTensorShape(val, tensor->mutable_tensor_shape());
642+
TF_RETURN_IF_ERROR(GetDenseTensorShape(val, tensor->mutable_tensor_shape()));
627643
int unused_size = 0;
628644
TF_RETURN_IF_ERROR(FillTensorProto(val, 0 /* level */, tensor->dtype(),
629645
&unused_size, tensor));
@@ -639,7 +655,7 @@ Status FillTensorMapFromInputsMap(
639655
auto* tensor = &(*tensor_map)[name];
640656
tensor->set_dtype(dtype);
641657
tensor->mutable_tensor_shape()->Clear();
642-
GetDenseTensorShape(item->value, tensor->mutable_tensor_shape());
658+
TF_RETURN_IF_ERROR(GetDenseTensorShape(item->value, tensor->mutable_tensor_shape()));
643659
int unused_size = 0;
644660
TF_RETURN_IF_ERROR(FillTensorProto(item->value, 0 /* level */, dtype,
645661
&unused_size, tensor));

tensorflow_serving/util/json_tensor_test.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,25 @@ TEST(JsontensorTest, DeeplyNestedMalformed) {
118118
EXPECT_THAT(status.message(), HasSubstr("key must be a string value"));
119119
}
120120

121+
TEST(JsontensorTest, DeeplyNestedTensorValueExceedsMaxRank) {
122+
TensorInfoMap infomap;
123+
ASSERT_TRUE(
124+
TextFormat::ParseFromString("dtype: DT_INT32", &infomap["default"]));
125+
126+
PredictRequest req;
127+
JsonPredictRequestFormat format;
128+
std::string json_req = R"({"instances":)";
129+
int depth = 300; // exceeds kMaxTensorRank (254)
130+
json_req.append(depth, '[');
131+
json_req.append("1");
132+
json_req.append(depth, ']');
133+
json_req.append("}");
134+
auto status =
135+
FillPredictRequestFromJson(json_req, getmap(infomap), &req, &format);
136+
ASSERT_FALSE(status.ok());
137+
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
138+
}
139+
121140
TEST(JsontensorTest, MixedInputForFloatTensor) {
122141
TensorInfoMap infomap;
123142
ASSERT_TRUE(

0 commit comments

Comments
 (0)