Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/cpu/tensor/gather_nd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,18 @@ Status GatherNDBase::PrepareForCompute(const TensorShape& input_shape, const Ten
const auto num_slices = indices_shape.SizeToDimension(indices_shape.NumDimensions() - 1);
const auto slice_size = input_shape.SizeFromDimension(SafeInt<size_t>(batch_dims_) + num_slice_dims);
const auto num_batches = input_shape.SizeToDimension(SafeInt<size_t>(batch_dims_));

// Validate batch dimensions to prevent division by zero
if (num_batches == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"GatherND: input tensor batch dimensions cannot be zero");
}
if (num_slices % num_batches != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"GatherND: indices batch size (", num_slices,
") is not divisible by input batch size (", num_batches, ")");
}

const auto input_batch_stride = input_shape.SizeFromDimension(SafeInt<size_t>(batch_dims_));
const auto num_slices_per_batch = num_slices / num_batches;
std::vector<int64_t> sizes_from_slice_dims(onnxruntime::narrow<size_t>(num_slice_dims));
Expand Down
13 changes: 13 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,5 +329,18 @@ TEST(GatherNDOpTest, GatherND_slice_int64_t) {
test.Run();
}

// Test for issue #23828: GatherND should return error instead of crashing
// when batch dimensions mismatch between input and indices
TEST(GatherNDOpTest, GatherND_batch_dims_mismatch_error) {
OpTester test("GatherND", 12, kOnnxDomain);
test.AddAttribute<int64_t>("batch_dims", 1);
// Input has 3 batches, but indices has 2 slices (indices batch size 2), which is not divisible by 3 - mismatch!
test.AddInput<float>("data", {3, 3}, {0.f, 1.f, 2.f, 10.f, 11.f, 12.f, 20.f, 21.f, 22.f});
test.AddInput<int64_t>("indices", {2, 1}, {1, 2});
test.AddOutput<float>("output", {2}, {0.f, 0.f}); // dummy output, won't be used
test.Run(OpTester::ExpectResult::kExpectFailure,
"GatherND: indices batch size (2) is not divisible by input batch size (3)");
}

} // namespace test
} // namespace onnxruntime
Loading