
#include "deep_ort_backend_plugin/ort_backend_executor.hpp"

+ #include <onnxruntime_cxx_api.h>
+
+ #include <cstring>
#include <memory>
#include <stdexcept>
#include <string>
@@ -29,9 +32,19 @@ OrtBackendExecutor::OrtBackendExecutor()
{
  env_ = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "deep_ort_backend");
  memory_info_ = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
+
+   // Register our custom allocator with the environment
+   auto custom_allocator_shared = get_ort_cpu_allocator();
+   auto * custom_allocator = static_cast<OrtCpuMemoryAllocator *>(custom_allocator_shared.get());
+   OrtStatus * status =
+     OrtGetApiBase()->GetApi(ORT_API_VERSION)->RegisterAllocator(*env_, custom_allocator->get_ort_allocator());
+   if (status != nullptr) {
+     OrtGetApiBase()->GetApi(ORT_API_VERSION)->ReleaseStatus(status);
+     // Log warning but don't fail - we can still work with default allocator
+   }
}

- bool OrtBackendExecutor::load_model(const std::filesystem::path & model_path)
+ bool OrtBackendExecutor::load_model_impl(const std::filesystem::path & model_path)
{
  if (!std::filesystem::exists(model_path)) {
    return false;
@@ -42,35 +55,35 @@ bool OrtBackendExecutor::load_model(const std::filesystem::path & model_path)
    session_options.SetIntraOpNumThreads(1);
    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

+     // Configure session to use environment allocators (our custom allocator)
+     session_options.AddConfigEntry("session.use_env_allocators", "1");
+
    session_ = std::make_unique<Ort::Session>(*env_, model_path.c_str(), session_options);

    model_path_ = model_path;
-     model_loaded_ = true;
    return true;
  } catch (const std::exception & e) {
-     model_loaded_ = false;
    return false;
  }
}

- deep_ros::Tensor OrtBackendExecutor::run_inference(deep_ros::Tensor input)
+ deep_ros::Tensor OrtBackendExecutor::run_inference_impl(deep_ros::Tensor & input)
{
-   if (!model_loaded_) {
-     throw std::runtime_error("No model loaded for inference");
-   }
-
  if (!session_) {
    throw std::runtime_error("No ONNX session available");
  }

  try {
    // Convert deep_ros::DataType to ONNX tensor element type
    ONNXTensorElementDataType onnx_type = convert_to_onnx_type(input.dtype());
-
-     // Create input OrtValue that wraps the input tensor's memory (zero-copy!)
-     size_t input_size_bytes = input.size() * get_element_size(input.dtype());
    std::vector<int64_t> input_shape_int64(input.shape().begin(), input.shape().end());

+     // Get our custom allocator for output binding
+     auto custom_allocator_shared = get_ort_cpu_allocator();
+     auto * custom_allocator = static_cast<OrtCpuMemoryAllocator *>(custom_allocator_shared.get());
+
+     // Create input tensor that wraps existing input memory (zero-copy!)
+     size_t input_size_bytes = input.size() * get_element_size(input.dtype());
    Ort::Value ort_input = Ort::Value::CreateTensor(
      memory_info_, input.data(), input_size_bytes, input_shape_int64.data(), input_shape_int64.size(), onnx_type);

@@ -79,42 +92,38 @@ deep_ros::Tensor OrtBackendExecutor::run_inference(deep_ros::Tensor input)
    auto input_name = session_->GetInputNameAllocated(0, allocator);
    auto output_name = session_->GetOutputNameAllocated(0, allocator);

-     // Get output shape (assuming we know it or can infer it)
-     auto output_shape = get_output_shape(input.shape());
-
-     // Allocate output tensor using our custom allocator
-     auto tensor_allocator = get_ort_cpu_allocator();
-     deep_ros::Tensor output(output_shape, input.dtype(), tensor_allocator);
-
-     // Create output OrtValue that wraps the output tensor's memory (zero-copy!)
-     size_t output_size_bytes = output.size() * get_element_size(output.dtype());
-     std::vector<int64_t> output_shape_int64(output.shape().begin(), output.shape().end());
-
-     Ort::Value ort_output = Ort::Value::CreateTensor(
-       memory_info_, output.data(), output_size_bytes, output_shape_int64.data(), output_shape_int64.size(), onnx_type);
-
    // Create IO binding for zero-copy inference
    Ort::IoBinding binding(*session_);
    binding.BindInput(input_name.get(), ort_input);
-     binding.BindOutput(output_name.get(), ort_output);

-     // Run inference with IO binding (zero-copy!)
+     // Bind output to use our custom allocator - ONNX Runtime will allocate using our allocator
+     binding.BindOutput(output_name.get(), custom_allocator->get_ort_memory_info());
+
+     // Run inference with IO binding (zero-copy for both input and output!)
    Ort::RunOptions run_options;
    session_->Run(run_options, binding);

+     // Get output values allocated by ONNX Runtime using our custom allocator
+     Ort::AllocatorWithDefaultOptions default_allocator;
+     std::vector<Ort::Value> output_tensors = binding.GetOutputValues(default_allocator);
+
+     // Get output shape and create our tensor wrapping the ONNX-allocated memory
+     auto output_shape = get_output_shape(input.shape());
+     void * output_data = output_tensors[0].GetTensorMutableData<void>();
+
+     // Create deep_ros tensor that wraps the ONNX-allocated memory (zero-copy!)
+     deep_ros::Tensor output(output_data, output_shape, input.dtype());
+
    return output;
  } catch (const std::exception & e) {
    throw std::runtime_error("ONNX Runtime inference failed: " + std::string(e.what()));
  }
}

- void OrtBackendExecutor::unload_model()
+ void OrtBackendExecutor::unload_model_impl()
{
-   if (model_loaded_) {
-     session_.reset();
-     model_loaded_ = false;
-     model_path_.clear();
-   }
+   session_.reset();
+   model_path_.clear();
}

std::vector<std::string> OrtBackendExecutor::supported_model_formats() const