Skip to content

Commit 63202cd

Browse files
committed
Fix GatherND division by zero when batch dimensions mismatch
Fixes #23828 Added validation to check: - num_batches is not zero - num_slices is divisible by num_batches Before this fix, mismatched batch dimensions caused a crash due to division by zero.
1 parent a3e477e commit 63202cd

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

onnxruntime/core/providers/cpu/tensor/gather_nd.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,18 @@ Status GatherNDBase::PrepareForCompute(const TensorShape& input_shape, const Ten
6666
const auto num_slices = indices_shape.SizeToDimension(indices_shape.NumDimensions() - 1);
6767
const auto slice_size = input_shape.SizeFromDimension(SafeInt<size_t>(batch_dims_) + num_slice_dims);
6868
const auto num_batches = input_shape.SizeToDimension(SafeInt<size_t>(batch_dims_));
69+
70+
// Validate batch dimensions to prevent division by zero
71+
if (num_batches == 0) {
72+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
73+
"GatherND: input tensor batch dimensions cannot be zero");
74+
}
75+
if (num_slices % num_batches != 0) {
76+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
77+
"GatherND: indices batch size (", num_slices,
78+
") must be divisible by input batch size (", num_batches, ")");
79+
}
80+
6981
const auto input_batch_stride = input_shape.SizeFromDimension(SafeInt<size_t>(batch_dims_));
7082
const auto num_slices_per_batch = num_slices / num_batches;
7183
std::vector<int64_t> sizes_from_slice_dims(onnxruntime::narrow<size_t>(num_slice_dims));

onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,5 +329,18 @@ TEST(GatherNDOpTest, GatherND_slice_int64_t) {
329329
test.Run();
330330
}
331331

332+
// Test for issue #23828: GatherND should return error instead of crashing
333+
// when batch dimensions mismatch between input and indices
334+
TEST(GatherNDOpTest, GatherND_batch_dims_mismatch_error) {
335+
OpTester test("GatherND", 12, kOnnxDomain);
336+
test.AddAttribute<int64_t>("batch_dims", 1);
337+
// Input has 3 batches, but indices has 2 batches - mismatch!
338+
test.AddInput<float>("data", {3, 3}, {0.f, 1.f, 2.f, 10.f, 11.f, 12.f, 20.f, 21.f, 22.f});
339+
test.AddInput<int64_t>("indices", {2, 1}, {1, 2});
340+
test.AddOutput<float>("output", {2}, {0.f, 0.f}); // dummy output, won't be used
341+
test.Run(OpTester::ExpectResult::kExpectFailure,
342+
"indices batch size (2) must be divisible by input batch size (3)");
343+
}
344+
332345
} // namespace test
333346
} // namespace onnxruntime

0 commit comments

Comments
 (0)