diff --git a/onnxruntime/core/providers/cpu/tensor/gather_nd.cc b/onnxruntime/core/providers/cpu/tensor/gather_nd.cc index ad3faa70ed6af..a0a848eef0dff 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather_nd.cc +++ b/onnxruntime/core/providers/cpu/tensor/gather_nd.cc @@ -66,6 +66,18 @@ Status GatherNDBase::PrepareForCompute(const TensorShape& input_shape, const Ten const auto num_slices = indices_shape.SizeToDimension(indices_shape.NumDimensions() - 1); const auto slice_size = input_shape.SizeFromDimension(SafeInt(batch_dims_) + num_slice_dims); const auto num_batches = input_shape.SizeToDimension(SafeInt(batch_dims_)); + + // Validate batch dimensions to prevent division by zero + if (num_batches == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "GatherND: input tensor batch dimensions cannot be zero"); + } + if (num_slices % num_batches != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "GatherND: indices batch size (", num_slices, + ") is not divisible by input batch size (", num_batches, ")"); + } + const auto input_batch_stride = input_shape.SizeFromDimension(SafeInt(batch_dims_)); const auto num_slices_per_batch = num_slices / num_batches; std::vector sizes_from_slice_dims(onnxruntime::narrow(num_slice_dims)); diff --git a/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc index 081b4b484a73b..a8f3b99b2b3d3 100644 --- a/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc @@ -329,5 +329,48 @@ TEST(GatherNDOpTest, GatherND_slice_int64_t) { test.Run(); } +// Test for issue #23828: GatherND should return error instead of crashing +// when batch dimensions mismatch between input and indices +TEST(GatherNDOpTest, GatherND_batch_dims_mismatch_error) { + OpTester test("GatherND", 12, kOnnxDomain); + test.AddAttribute("batch_dims", 1); + + // Input has 3 batches, but indices has 2 slices (indices batch size 2), which is not divisible by 3 - mismatch! + test.AddInput("data", {3, 3}, {0.f, 1.f, 2.f, 10.f, 11.f, 12.f, 20.f, 21.f, 22.f}); + test.AddInput("indices", {2, 1}, {1, 2}); + test.AddOutput("output", {2}, {0.f, 0.f}); // dummy output, won't be used + + // Force execution only on CPU + std::vector> cpu_only_ep; + cpu_only_ep.push_back(DefaultCpuExecutionProvider()); + + test.Run(OpTester::ExpectResult::kExpectFailure, + "GatherND: indices batch size (2) is not divisible by input batch size (3)", + {}, // no excluded providers needed + nullptr, // no RunOptions + &cpu_only_ep); // force CPU +} + +// Test for issue #23828: GatherND should return error when input batch dimension is zero +TEST(GatherNDOpTest, GatherND_zero_batch_dims_error) { + OpTester test("GatherND", 12, kOnnxDomain); + test.AddAttribute("batch_dims", 1); + + // Input has 0 batches - should fail with clear error instead of division by zero + test.AddInput("data", {0, 3}, {}); + test.AddInput("indices", {2, 1}, {1, 2}); + test.AddOutput("output", {2}, {0.f, 0.f}); // dummy output, won't be used + + // Force execution only on CPU + std::vector> cpu_only_ep; + cpu_only_ep.push_back(DefaultCpuExecutionProvider()); + + test.Run(OpTester::ExpectResult::kExpectFailure, + "GatherND: input tensor batch dimensions cannot be zero", + {}, + nullptr, + &cpu_only_ep); // force CPU +} + } // namespace test } // namespace onnxruntime