Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/cpu/tensor/gather_nd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,18 @@ Status GatherNDBase::PrepareForCompute(const TensorShape& input_shape, const Ten
const auto num_slices = indices_shape.SizeToDimension(indices_shape.NumDimensions() - 1);
const auto slice_size = input_shape.SizeFromDimension(SafeInt<size_t>(batch_dims_) + num_slice_dims);
const auto num_batches = input_shape.SizeToDimension(SafeInt<size_t>(batch_dims_));

// Validate batch dimensions to prevent division by zero
if (num_batches == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"GatherND: input tensor batch dimensions cannot be zero");
}
if (num_slices % num_batches != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"GatherND: indices batch size (", num_slices,
") is not divisible by input batch size (", num_batches, ")");
}

const auto input_batch_stride = input_shape.SizeFromDimension(SafeInt<size_t>(batch_dims_));
const auto num_slices_per_batch = num_slices / num_batches;
std::vector<int64_t> sizes_from_slice_dims(onnxruntime::narrow<size_t>(num_slice_dims));
Expand Down
35 changes: 35 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,5 +329,40 @@ TEST(GatherNDOpTest, GatherND_slice_int64_t) {
test.Run();
}

// Test for issue #23828: GatherND should return error instead of crashing
// when batch dimensions mismatch between input and indices
TEST(GatherNDOpTest, GatherND_batch_dims_mismatch_error) {
OpTester test("GatherND", 12, kOnnxDomain);
test.AddAttribute<int64_t>("batch_dims", 1);
// Input has 3 batches, but indices has 2 slices (indices batch size 2), which is not divisible by 3 - mismatch!
test.AddInput<float>("data", {3, 3}, {0.f, 1.f, 2.f, 10.f, 11.f, 12.f, 20.f, 21.f, 22.f});
test.AddInput<int64_t>("indices", {2, 1}, {1, 2});
test.AddOutput<float>("output", {2}, {0.f, 0.f}); // dummy output, won't be used
// Run only on CPU provider since validation logic is CPU-specific
test.Run(OpTester::ExpectResult::kExpectFailure,
"GatherND: indices batch size (2) is not divisible by input batch size (3)",
std::unordered_set<std::string>({kCudaExecutionProvider, kDnnlExecutionProvider,
kOpenVINOExecutionProvider, kTensorrtExecutionProvider,
kQnnExecutionProvider, kDmlExecutionProvider
}));
}

// Test for issue #23828: GatherND should return error when input batch dimension is zero
TEST(GatherNDOpTest, GatherND_zero_batch_dims_error) {
OpTester test("GatherND", 12, kOnnxDomain);
test.AddAttribute<int64_t>("batch_dims", 1);
// Input has 0 batches - should fail with clear error instead of division by zero
test.AddInput<float>("data", {0, 3}, {});
test.AddInput<int64_t>("indices", {2, 1}, {1, 2});
test.AddOutput<float>("output", {2}, {0.f, 0.f}); // dummy output, won't be used
// Run only on CPU provider since validation logic is CPU-specific
test.Run(OpTester::ExpectResult::kExpectFailure,
"GatherND: input tensor batch dimensions cannot be zero",
std::unordered_set<std::string>({kCudaExecutionProvider, kDnnlExecutionProvider,
kOpenVINOExecutionProvider, kTensorrtExecutionProvider,
kQnnExecutionProvider, kDmlExecutionProvider
}));
}

} // namespace test
} // namespace onnxruntime
Loading