Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_iobinding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ void addIoBindingMethods(pybind11::module& m) {
})
// This binds input as a Tensor that wraps memory pointer along with the OrtMemoryInfo
.def("bind_input", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, int32_t element_type, const std::vector<int64_t>& shape, int64_t data_ptr) -> void {
// String tensors require live std::string objects in the backing buffer; the raw-pointer
// overload only wraps caller-provided bytes, so binding a string tensor here would lead
// to reading/writing through uninitialized std::string storage. Reject it explicitly.
if (element_type == onnx::TensorProto::STRING) {
throw std::runtime_error("Only binding non-string Tensors is currently supported");
}
auto ml_type = OnnxTypeToOnnxRuntimeTensorType(element_type);
OrtValue ml_value;
OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device);
Expand All @@ -113,6 +119,15 @@ void addIoBindingMethods(pybind11::module& m) {

OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device);
auto ml_type = NumpyTypeToOnnxRuntimeTensorType(type_num);
// See comment in the int32_t element_type overload above: string tensors are not safe
// to bind via a raw, non-owning pointer because no std::string objects are constructed
// in the caller buffer. Compare against the ONNX type enum rather than the singleton
// MLDataType pointer so the check stays correct even if the type registry returns a
// different (but equivalent) instance.
const auto* primitive_type = ml_type->AsPrimitiveDataType();
if (primitive_type != nullptr && primitive_type->GetDataType() == onnx::TensorProto::STRING) {
throw std::runtime_error("Only binding non-string Tensors is currently supported");
}
OrtValue ml_value;
Tensor::InitOrtValue(ml_type, gsl::make_span(shape), reinterpret_cast<void*>(data_ptr), info, ml_value);

Expand Down
26 changes: 26 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_python_iobinding.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,32 @@ def test_bind_input_and_bind_output_with_ortvalues(self):
# Inspect contents of output_ortvalue and make sure that it has the right contents
self.assertTrue(np.array_equal(self._create_expected_output_alternate(), output_ortvalue.numpy()))

def test_bind_input_rejects_string_tensor(self):
    # A string tensor cannot be bound through a raw, non-owning pointer: the caller's
    # buffer holds no live std::string objects, which previously caused out-of-bounds
    # writes once the tensor was read or destroyed. Both bind_input overloads (ONNX
    # integer element type and numpy dtype) are expected to refuse such a binding.
    expected_error = "Only binding non-string Tensors"

    model_path = get_name("mul_1.onnx")
    providers = onnxrt.get_available_providers()
    sess = onnxrt.InferenceSession(model_path, providers=providers)
    binding = sess.io_binding()

    # Any valid address will do here; the type check is expected to fire before the
    # pointer is ever dereferenced.
    backing = np.zeros(4, dtype=np.uint8)
    address = backing.ctypes.data

    # Overload taking the ONNX element type as an int.
    with self.assertRaisesRegex(RuntimeError, expected_error):
        binding.bind_input("X", "cpu", 0, int(TensorProto.STRING), [1], address)

    # Overload taking a numpy dtype. NPY_UNICODE, NPY_STRING and NPY_OBJECT all map to
    # std::string in NumpyTypeToOnnxRuntimeTensorType, so every one must be refused.
    string_like_dtypes = [np.dtype(spec) for spec in ("U1", "S1", object)]
    for candidate in string_like_dtypes:
        with self.subTest(dtype=candidate):
            with self.assertRaisesRegex(RuntimeError, expected_error):
                binding.bind_input("X", "cpu", 0, candidate, [1], address)


# Run the unittest test runner when this file is executed directly as a script.
if __name__ == "__main__":
unittest.main()
Loading