Commit 31dc8f9: "tests pass"
Parent: 1bb2e82

10 files changed: +181, -70 lines


DEVELOPING.md (3 additions, 0 deletions)

@@ -31,6 +31,9 @@ This project includes VS Code dev container configurations for easy ROS2 development
 Inside the container:
 
 ```bash
+# Update apt
+sudo apt update
+
 # Update rosdep
 rosdep update
 

deep_core/include/deep_core/deep_node_base.hpp (2 additions, 2 deletions)

@@ -135,10 +135,10 @@ class DeepNodeBase : public rclcpp_lifecycle::LifecycleNode
 
   /**
    * @brief Run inference on input tensor
-   * @param inputs Input tensor for inference
+   * @param inputs Input tensor for inference (note: some backends may require mutable access for zero-copy operations)
    * @return Output tensor from inference
    */
-  Tensor run_inference(const Tensor & inputs);
+  Tensor run_inference(Tensor & inputs);
 
   /**
    * @brief Check if a backend plugin is loaded
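Because run_inference now takes a non-const Tensor &, call sites must pass a named, non-const tensor; temporaries no longer bind. A minimal caller sketch, assuming a hypothetical node method and make_input() helper (neither is part of this repository):

```cpp
// Hypothetical call site inside a node derived from DeepNodeBase.
// MyNode and make_input() are illustrative only; the Tensor construction API is not shown in this commit.
void MyNode::infer_once()
{
  deep_ros::Tensor input = make_input();           // must be a named, non-const lvalue
  deep_ros::Tensor output = run_inference(input);  // binds to Tensor &, so the backend may wrap the buffer in place

  // This would no longer compile with the new signature:
  // run_inference(make_input());  // an rvalue cannot bind to a non-const Tensor &

  publish_result(output);                          // illustrative downstream use
}
```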

deep_core/include/deep_core/plugin_interfaces/backend_inference_executor.hpp (4 additions, 3 deletions)

@@ -45,12 +45,12 @@ class BackendInferenceExecutor
 
   /**
    * @brief Run inference on input tensor
-   * @param input Input tensor
+   * @param input Input tensor (note: some backends may require mutable access for zero-copy operations)
    * @return Output tensor
    * @throws std::invalid_argument if input tensor is invalid
    * @throws std::runtime_error if no model is loaded
    */
-  Tensor run_inference(const Tensor & input);
+  Tensor run_inference(Tensor & input);
 
   /**
    * @brief Unload the currently loaded model

@@ -80,8 +80,9 @@ class BackendInferenceExecutor
 
   /**
    * @brief Implementation of run_inference (to be overridden by backends)
+   * @param input Input tensor (note: some backends may require mutable access)
    */
-  virtual Tensor run_inference_impl(const Tensor & input) = 0;
+  virtual Tensor run_inference_impl(Tensor & input) = 0;
 
   /**
    * @brief Implementation of unload_model (to be overridden by backends)
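For context on how a backend plugs into this non-virtual interface, here is a minimal sketch of a pass-through backend that overrides the protected _impl hooks. The hook names come from this commit (load_model_impl and unload_model_impl appear in the ORT plugin header further down), but whether each hook is pure virtual, and the exact access levels, are assumptions:

```cpp
// Minimal pass-through backend sketch against the interface above.
#include <filesystem>
#include <string>
#include <vector>

#include "deep_core/plugin_interfaces/backend_inference_executor.hpp"

class PassThroughExecutor : public deep_ros::BackendInferenceExecutor
{
public:
  std::vector<std::string> supported_model_formats() const override
  {
    return {"passthrough"};
  }

protected:
  bool load_model_impl(const std::filesystem::path & /*model_path*/) override
  {
    return true;  // nothing to load for a pass-through backend
  }

  deep_ros::Tensor run_inference_impl(deep_ros::Tensor & input) override
  {
    return input;  // mutable access lets a real backend wrap or reuse the buffer
  }

  void unload_model_impl() override {}
};
```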

deep_core/package.xml (0 additions, 1 deletion)

@@ -13,7 +13,6 @@
   <depend>pluginlib</depend>
   <depend>bondcpp</depend>
 
-
   <export>
     <build_type>ament_cmake</build_type>
   </export>

deep_core/src/backend_inference_executor.cpp (1 addition, 1 deletion)

@@ -34,7 +34,7 @@ bool BackendInferenceExecutor::load_model(const std::filesystem::path & model_path)
   return success;
 }
 
-Tensor BackendInferenceExecutor::run_inference(const Tensor & input)
+Tensor BackendInferenceExecutor::run_inference(Tensor & input)
 {
   // Validate input tensor
   if (input.data() == nullptr) {
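The hunk above ends inside the validation block. A hedged sketch of how the non-virtual wrapper plausibly continues, following the @throws contract documented in backend_inference_executor.hpp (the actual remainder of this function is not part of the diff):

```cpp
// Hedged continuation sketch, not the commit's verbatim code. The real body
// likely also enforces the "no model is loaded" contract from the header.
Tensor BackendInferenceExecutor::run_inference(Tensor & input)
{
  // Validate input tensor
  if (input.data() == nullptr) {
    throw std::invalid_argument("Input tensor has no data");
  }
  // Delegate to the backend-specific hook
  return run_inference_impl(input);
}
```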

deep_core/src/deep_node_base.cpp (1 addition, 1 deletion)

@@ -191,7 +191,7 @@ void DeepNodeBase::unload_model()
   }
 }
 
-Tensor DeepNodeBase::run_inference(const Tensor & inputs)
+Tensor DeepNodeBase::run_inference(Tensor & inputs)
 {
   if (!plugin_) {
     throw std::runtime_error("No plugin loaded");

deep_ort_backend_plugin/include/deep_ort_backend_plugin/ort_backend_executor.hpp (10 additions, 10 deletions)

@@ -46,34 +46,34 @@ class OrtBackendExecutor : public deep_ros::BackendInferenceExecutor
    */
   ~OrtBackendExecutor() override = default;
 
+  /**
+   * @brief Get supported model formats
+   * @return Vector containing "onnx"
+   */
+  std::vector<std::string> supported_model_formats() const override;
+
+protected:
   /**
    * @brief Load an ONNX model from file
    * @param model_path Path to the .onnx model file
    * @return true if successful, false otherwise
    */
-  bool load_model(const std::filesystem::path & model_path) override;
+  bool load_model_impl(const std::filesystem::path & model_path) override;
 
   /**
    * @brief Run inference using zero-copy IO binding
    * @param input Input tensor (must be compatible with model input)
    * @return Output tensor with inference results
    * @throws std::runtime_error if inference fails or no model loaded
    */
-  deep_ros::Tensor run_inference(deep_ros::Tensor input) override;
+  deep_ros::Tensor run_inference_impl(deep_ros::Tensor & input) override;
 
   /**
    * @brief Unload the currently loaded model
    */
-  void unload_model() override;
-
-  /**
-   * @brief Get supported model formats
-   * @return Vector containing "onnx"
-   */
-  std::vector<std::string> supported_model_formats() const override;
+  void unload_model_impl() override;
 
 private:
-  bool model_loaded_{false};
   std::filesystem::path model_path_;
 
   std::unique_ptr<Ort::Env> env_;
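This executor is consumed through pluginlib (a dependency declared in deep_core/package.xml above). The registration itself is not part of this diff; a typical export for a class like this would look roughly as follows, where the deep_ort_backend namespace is inferred from the allocator header below and the base class from this header:

```cpp
// Hedged sketch of a typical pluginlib registration; not part of this commit.
#include <pluginlib/class_list_macros.hpp>

PLUGINLIB_EXPORT_CLASS(deep_ort_backend::OrtBackendExecutor, deep_ros::BackendInferenceExecutor)
```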

deep_ort_backend_plugin/include/deep_ort_backend_plugin/ort_cpu_memory_allocator.hpp (44 additions, 15 deletions)

@@ -14,6 +14,8 @@
 
 #pragma once
 
+#include <onnxruntime_c_api.h>
+
 #include <memory>
 #include <string>
 

@@ -26,7 +28,8 @@ namespace deep_ort_backend
  * @brief ONNX Runtime optimized CPU memory allocator
  *
  * Provides CPU memory allocation optimized for ONNX Runtime operations
- * with proper alignment for SIMD operations.
+ * with proper alignment for SIMD operations. Implements both deep_ros
+ * BackendMemoryAllocator interface and OrtAllocator interface directly.
  */
 class OrtCpuMemoryAllocator : public deep_ros::BackendMemoryAllocator
 {

@@ -39,7 +42,19 @@ class OrtCpuMemoryAllocator : public deep_ros::BackendMemoryAllocator
   /**
    * @brief Destructor
    */
-  ~OrtCpuMemoryAllocator() override = default;
+  ~OrtCpuMemoryAllocator() override;
+
+  /**
+   * @brief Get the OrtAllocator interface for use with ONNX Runtime
+   * @return Pointer to OrtAllocator struct
+   */
+  OrtAllocator * get_ort_allocator();
+
+  /**
+   * @brief Get the OrtMemoryInfo for this allocator
+   * @return Pointer to OrtMemoryInfo
+   */
+  const OrtMemoryInfo * get_ort_memory_info() const;
 
   /**
    * @brief Allocate aligned memory for CPU operations

@@ -54,41 +69,55 @@ class OrtCpuMemoryAllocator : public deep_ros::BackendMemoryAllocator
    */
   void deallocate(void * ptr) override;
 
+  /**
+   * @brief Check if this is device memory
+   * @return false (CPU memory is host memory)
+   */
+  bool is_device_memory() const override;
+
+  /**
+   * @brief Get device name
+   * @return "cpu"
+   */
+  std::string device_name() const override;
+
+protected:
   /**
    * @brief Copy from host memory (same as device for CPU)
    * @param dst Destination pointer
    * @param src Source pointer
    * @param bytes Number of bytes to copy
    */
-  void copy_from_host(void * dst, const void * src, size_t bytes) override;
+  void copy_from_host_impl(void * dst, const void * src, size_t bytes) override;
 
   /**
    * @brief Copy to host memory (same as device for CPU)
    * @param dst Destination pointer
    * @param src Source pointer
    * @param bytes Number of bytes to copy
    */
-  void copy_to_host(void * dst, const void * src, size_t bytes) override;
+  void copy_to_host_impl(void * dst, const void * src, size_t bytes) override;
 
   /**
    * @brief Copy between CPU memory locations
    * @param dst Destination pointer
    * @param src Source pointer
    * @param bytes Number of bytes to copy
    */
-  void copy_device_to_device(void * dst, const void * src, size_t bytes) override;
+  void copy_device_to_device_impl(void * dst, const void * src, size_t bytes) override;
 
-  /**
-   * @brief Check if this is device memory
-   * @return false (CPU memory is host memory)
-   */
-  bool is_device_memory() const override;
+private:
+  OrtAllocator ort_allocator_;
+  OrtMemoryInfo * ort_memory_info_;
 
-  /**
-   * @brief Get device name
-   * @return "cpu"
-   */
-  std::string device_name() const override;
+  // Store a pointer to self in a way that callbacks can access it
+  static OrtCpuMemoryAllocator * instance_;
+
+  // Static callback functions for OrtAllocator interface
+  static void * ORT_API_CALL ort_alloc(OrtAllocator * this_, size_t size);
+  static void ORT_API_CALL ort_free(OrtAllocator * this_, void * p);
+  static const OrtMemoryInfo * ORT_API_CALL ort_info(const OrtAllocator * this_);
+  static void * ORT_API_CALL ort_reserve(OrtAllocator * this_, size_t size);
 };
 
 /**
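The matching .cpp for this allocator is not included in the hunks above. A hedged sketch of how the static callbacks could be wired into the OrtAllocator struct: the field names version, Alloc, Free, Info, and Reserve come from onnxruntime_c_api.h, while the allocate(size_t) signature of the base allocator and the use of CreateCpuMemoryInfo are assumptions:

```cpp
// Hedged wiring sketch; not the commit's actual implementation.
OrtCpuMemoryAllocator::OrtCpuMemoryAllocator()
{
  instance_ = this;  // let the C-style callbacks reach this object

  ort_allocator_.version = ORT_API_VERSION;
  ort_allocator_.Alloc = ort_alloc;
  ort_allocator_.Free = ort_free;
  ort_allocator_.Info = ort_info;
  ort_allocator_.Reserve = ort_reserve;  // field exists in recent ONNX Runtime versions

  // ort_memory_info_ would be created here, e.g. via OrtApi::CreateCpuMemoryInfo (assumption).
}

void * ORT_API_CALL OrtCpuMemoryAllocator::ort_alloc(OrtAllocator * /*this_*/, size_t size)
{
  // Forward into the deep_ros allocator; allocate(size_t) is an assumed signature.
  return instance_ ? instance_->allocate(size) : nullptr;
}

void ORT_API_CALL OrtCpuMemoryAllocator::ort_free(OrtAllocator * /*this_*/, void * p)
{
  if (instance_) {
    instance_->deallocate(p);
  }
}

const OrtMemoryInfo * ORT_API_CALL OrtCpuMemoryAllocator::ort_info(const OrtAllocator * /*this_*/)
{
  return instance_ ? instance_->get_ort_memory_info() : nullptr;
}
```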

deep_ort_backend_plugin/src/ort_backend_executor.cpp (42 additions, 33 deletions)

@@ -14,6 +14,9 @@
 
 #include "deep_ort_backend_plugin/ort_backend_executor.hpp"
 
+#include <onnxruntime_cxx_api.h>
+
+#include <cstring>
 #include <memory>
 #include <stdexcept>
 #include <string>

@@ -29,9 +32,19 @@ OrtBackendExecutor::OrtBackendExecutor()
 {
   env_ = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "deep_ort_backend");
   memory_info_ = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
+
+  // Register our custom allocator with the environment
+  auto custom_allocator_shared = get_ort_cpu_allocator();
+  auto * custom_allocator = static_cast<OrtCpuMemoryAllocator *>(custom_allocator_shared.get());
+  OrtStatus * status =
+    OrtGetApiBase()->GetApi(ORT_API_VERSION)->RegisterAllocator(*env_, custom_allocator->get_ort_allocator());
+  if (status != nullptr) {
+    OrtGetApiBase()->GetApi(ORT_API_VERSION)->ReleaseStatus(status);
+    // Log warning but don't fail - we can still work with default allocator
+  }
 }
 
-bool OrtBackendExecutor::load_model(const std::filesystem::path & model_path)
+bool OrtBackendExecutor::load_model_impl(const std::filesystem::path & model_path)
 {
   if (!std::filesystem::exists(model_path)) {
     return false;

@@ -42,35 +55,35 @@ bool OrtBackendExecutor::load_model(const std::filesystem::path & model_path)
     session_options.SetIntraOpNumThreads(1);
     session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
 
+    // Configure session to use environment allocators (our custom allocator)
+    session_options.AddConfigEntry("session.use_env_allocators", "1");
+
     session_ = std::make_unique<Ort::Session>(*env_, model_path.c_str(), session_options);
 
     model_path_ = model_path;
-    model_loaded_ = true;
     return true;
   } catch (const std::exception & e) {
-    model_loaded_ = false;
     return false;
   }
 }
 
-deep_ros::Tensor OrtBackendExecutor::run_inference(deep_ros::Tensor input)
+deep_ros::Tensor OrtBackendExecutor::run_inference_impl(deep_ros::Tensor & input)
 {
-  if (!model_loaded_) {
-    throw std::runtime_error("No model loaded for inference");
-  }
-
   if (!session_) {
     throw std::runtime_error("No ONNX session available");
   }
 
   try {
     // Convert deep_ros::DataType to ONNX tensor element type
     ONNXTensorElementDataType onnx_type = convert_to_onnx_type(input.dtype());
-
-    // Create input OrtValue that wraps the input tensor's memory (zero-copy!)
-    size_t input_size_bytes = input.size() * get_element_size(input.dtype());
     std::vector<int64_t> input_shape_int64(input.shape().begin(), input.shape().end());
 
+    // Get our custom allocator for output binding
+    auto custom_allocator_shared = get_ort_cpu_allocator();
+    auto * custom_allocator = static_cast<OrtCpuMemoryAllocator *>(custom_allocator_shared.get());
+
+    // Create input tensor that wraps existing input memory (zero-copy!)
+    size_t input_size_bytes = input.size() * get_element_size(input.dtype());
     Ort::Value ort_input = Ort::Value::CreateTensor(
       memory_info_, input.data(), input_size_bytes, input_shape_int64.data(), input_shape_int64.size(), onnx_type);
 

@@ -79,42 +92,38 @@ deep_ros::Tensor OrtBackendExecutor::run_inference(deep_ros::Tensor input)
     auto input_name = session_->GetInputNameAllocated(0, allocator);
     auto output_name = session_->GetOutputNameAllocated(0, allocator);
 
-    // Get output shape (assuming we know it or can infer it)
-    auto output_shape = get_output_shape(input.shape());
-
-    // Allocate output tensor using our custom allocator
-    auto tensor_allocator = get_ort_cpu_allocator();
-    deep_ros::Tensor output(output_shape, input.dtype(), tensor_allocator);
-
-    // Create output OrtValue that wraps the output tensor's memory (zero-copy!)
-    size_t output_size_bytes = output.size() * get_element_size(output.dtype());
-    std::vector<int64_t> output_shape_int64(output.shape().begin(), output.shape().end());
-
-    Ort::Value ort_output = Ort::Value::CreateTensor(
-      memory_info_, output.data(), output_size_bytes, output_shape_int64.data(), output_shape_int64.size(), onnx_type);
-
     // Create IO binding for zero-copy inference
     Ort::IoBinding binding(*session_);
     binding.BindInput(input_name.get(), ort_input);
-    binding.BindOutput(output_name.get(), ort_output);
 
-    // Run inference with IO binding (zero-copy!)
+    // Bind output to use our custom allocator - ONNX Runtime will allocate using our allocator
+    binding.BindOutput(output_name.get(), custom_allocator->get_ort_memory_info());
+
+    // Run inference with IO binding (zero-copy for both input and output!)
     Ort::RunOptions run_options;
     session_->Run(run_options, binding);
 
+    // Get output values allocated by ONNX Runtime using our custom allocator
+    Ort::AllocatorWithDefaultOptions default_allocator;
+    std::vector<Ort::Value> output_tensors = binding.GetOutputValues(default_allocator);
+
+    // Get output shape and create our tensor wrapping the ONNX-allocated memory
+    auto output_shape = get_output_shape(input.shape());
+    void * output_data = output_tensors[0].GetTensorMutableData<void>();
+
+    // Create deep_ros tensor that wraps the ONNX-allocated memory (zero-copy!)
+    deep_ros::Tensor output(output_data, output_shape, input.dtype());
+
     return output;
   } catch (const std::exception & e) {
     throw std::runtime_error("ONNX Runtime inference failed: " + std::string(e.what()));
   }
 }
 
-void OrtBackendExecutor::unload_model()
+void OrtBackendExecutor::unload_model_impl()
 {
-  if (model_loaded_) {
-    session_.reset();
-    model_loaded_ = false;
-    model_path_.clear();
-  }
+  session_.reset();
+  model_path_.clear();
 }
 
 std::vector<std::string> OrtBackendExecutor::supported_model_formats() const
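The helpers convert_to_onnx_type and get_element_size are called above but not changed in this commit, so their bodies are not shown. A hedged sketch of what such mappings typically look like (the deep_ros::DataType enumerator names are assumptions; the ONNX enum values are real):

```cpp
// Illustrative dtype mapping helpers; not the repository's actual implementations.
ONNXTensorElementDataType convert_to_onnx_type(deep_ros::DataType dtype)
{
  switch (dtype) {
    case deep_ros::DataType::FLOAT32: return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
    case deep_ros::DataType::INT64:   return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
    case deep_ros::DataType::UINT8:   return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
    default: throw std::invalid_argument("Unsupported dtype");
  }
}

size_t get_element_size(deep_ros::DataType dtype)
{
  switch (dtype) {
    case deep_ros::DataType::FLOAT32: return sizeof(float);    // 4 bytes
    case deep_ros::DataType::INT64:   return sizeof(int64_t);  // 8 bytes
    case deep_ros::DataType::UINT8:   return sizeof(uint8_t);  // 1 byte
    default: throw std::invalid_argument("Unsupported dtype");
  }
}
```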
