diff --git a/onnxruntime/python/onnxruntime_pybind_iobinding.cc b/onnxruntime/python/onnxruntime_pybind_iobinding.cc index b82dd1474bdf6..d960444f240a4 100644 --- a/onnxruntime/python/onnxruntime_pybind_iobinding.cc +++ b/onnxruntime/python/onnxruntime_pybind_iobinding.cc @@ -92,6 +92,12 @@ void addIoBindingMethods(pybind11::module& m) { }) // This binds input as a Tensor that wraps memory pointer along with the OrtMemoryInfo .def("bind_input", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, int32_t element_type, const std::vector& shape, int64_t data_ptr) -> void { + // String tensors require live std::string objects in the backing buffer; the raw-pointer + // overload only wraps caller-provided bytes, so binding a string tensor here would lead + // to reading/writing through uninitialized std::string storage. Reject it explicitly. + if (element_type == onnx::TensorProto::STRING) { + throw std::runtime_error("Only binding non-string Tensors is currently supported"); + } auto ml_type = OnnxTypeToOnnxRuntimeTensorType(element_type); OrtValue ml_value; OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device); @@ -113,6 +119,15 @@ void addIoBindingMethods(pybind11::module& m) { OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device); auto ml_type = NumpyTypeToOnnxRuntimeTensorType(type_num); + // See comment in the int32_t element_type overload above: string tensors are not safe + // to bind via a raw, non-owning pointer because no std::string objects are constructed + // in the caller buffer. Compare against the ONNX type enum rather than the singleton + // MLDataType pointer so the check stays correct even if the type registry returns a + // different (but equivalent) instance. + const auto* primitive_type = ml_type->AsPrimitiveDataType(); + if (primitive_type != nullptr && primitive_type->GetDataType() == onnx::TensorProto::STRING) { + throw std::runtime_error("Only binding non-string Tensors is currently supported"); + } OrtValue ml_value; Tensor::InitOrtValue(ml_type, gsl::make_span(shape), reinterpret_cast(data_ptr), info, ml_value); diff --git a/onnxruntime/test/python/onnxruntime_test_python_iobinding.py b/onnxruntime/test/python/onnxruntime_test_python_iobinding.py index 0e0c62bba5d50..86b5473ecb4b8 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_iobinding.py +++ b/onnxruntime/test/python/onnxruntime_test_python_iobinding.py @@ -441,6 +441,32 @@ def test_bind_input_and_bind_output_with_ortvalues(self): # Inspect contents of output_ortvalue and make sure that it has the right contents self.assertTrue(np.array_equal(self._create_expected_output_alternate(), output_ortvalue.numpy())) + def test_bind_input_rejects_string_tensor(self): + # Binding a string tensor via a raw, non-owning pointer is unsafe: the backing buffer + # has no live std::string objects, which previously caused out-of-bounds writes when + # the tensor was later read or destroyed. Both overloads of bind_input (ONNX int + # element_type and numpy dtype) must reject string tensors explicitly. + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) + io_binding = session.io_binding() + + # Use a real allocation just to have a valid pointer; the type check happens before + # the pointer is dereferenced. + scratch = np.zeros(4, dtype=np.uint8) + scratch_ptr = scratch.ctypes.data + + # Overload 1: int32 ONNX element type. + with self.assertRaisesRegex(RuntimeError, "Only binding non-string Tensors"): + io_binding.bind_input("X", "cpu", 0, int(TensorProto.STRING), [1], scratch_ptr) + + # Overload 2: numpy dtype. NPY_UNICODE, NPY_STRING and NPY_OBJECT all map to + # std::string in NumpyTypeToOnnxRuntimeTensorType, so each of them must be rejected. + for dtype in (np.dtype("U1"), np.dtype("S1"), np.dtype(object)): + with ( + self.subTest(dtype=dtype), + self.assertRaisesRegex(RuntimeError, "Only binding non-string Tensors"), + ): + io_binding.bind_input("X", "cpu", 0, dtype, [1], scratch_ptr) + if __name__ == "__main__": unittest.main()