Skip to content

Commit ad9437b

Browse files
committed
some changes
1 parent 9dcda4a commit ad9437b

File tree

11 files changed

+498
-530
lines changed

11 files changed

+498
-530
lines changed

deep_object_detection/README.md

Lines changed: 192 additions & 79 deletions
Large diffs are not rendered by default.

deep_object_detection/config/generic_model_params.yaml

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,7 @@ deep_object_detection_node:
1919
nms_iou_threshold: 0.45
2020
max_detections: 300
2121
score_activation: "sigmoid"
22-
enable_nms: true
2322
use_multi_output: false
24-
output_boxes_idx: 0
25-
output_scores_idx: 1
26-
output_classes_idx: 2
2723

2824
class_score_mode: "all_classes"
2925
class_score_start_idx: -1
@@ -41,20 +37,19 @@ deep_object_detection_node:
4137
score_idx: 4
4238
class_idx: 5
4339

44-
use_camera_sync: true
45-
camera_sync_topic: "/multi_camera_sync/multi_image_compressed"
46-
input_qos_reliability: "best_effort"
40+
input_topic: "/multi_camera_sync/multi_image_compressed"
4741
output_detections_topic: "/detections"
4842

49-
min_batch_size: 1
5043
max_batch_size: 3
51-
max_batch_latency_ms: 0
5244
queue_size: 10
53-
queue_overflow_policy: "drop_oldest"
54-
decode_failure_policy: "drop"
5545

5646
preferred_provider: "tensorrt"
5747
device_id: 0
58-
warmup_tensor_shapes: true
5948
enable_trt_engine_cache: true
6049
trt_engine_cache_path: "/tmp/deep_ros_ort_trt_cache"
50+
51+
Backend:
52+
execution_provider: "tensorrt"
53+
device_id: 0
54+
trt_engine_cache_enable: true
55+
trt_engine_cache_path: "/tmp/deep_ros_ort_trt_cache"

deep_object_detection/include/deep_object_detection/deep_object_detection_node.hpp

Lines changed: 20 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414

1515
#pragma once
1616

17-
#include <atomic>
1817
#include <deque>
1918
#include <memory>
2019
#include <mutex>
2120
#include <string>
2221
#include <vector>
2322

23+
#include <deep_msgs/msg/multi_image.hpp>
2424
#include <opencv2/core/mat.hpp>
2525
#include <rclcpp/node_options.hpp>
2626
#include <rclcpp/rclcpp.hpp>
@@ -29,8 +29,6 @@
2929
#include <sensor_msgs/msg/compressed_image.hpp>
3030
#include <std_msgs/msg/header.hpp>
3131

32-
#include <deep_msgs/msg/multi_image.hpp>
33-
3432
#include "deep_object_detection/backend_manager.hpp"
3533
#include "deep_object_detection/detection_types.hpp"
3634
#include "deep_object_detection/generic_postprocessor.hpp"
@@ -42,9 +40,9 @@ namespace deep_object_detection
4240
/**
4341
* @brief ROS2 lifecycle node for object detection using ONNX models
4442
*
45-
* This node performs object detection on images from cameras or synchronized multi-camera streams.
43+
* This node performs object detection on synchronized multi-camera streams via MultiImage messages.
4644
* It supports:
47-
* - Multiple input modes: individual camera topics or synchronized MultiImage messages
45+
* - MultiImage input: synchronized compressed images from multiple cameras
4846
* - Batch processing: groups images for efficient inference
4947
* - Multiple backends: CPU, CUDA, or TensorRT execution providers
5048
* - Configurable preprocessing: resizing, normalization, color format conversion
@@ -114,143 +112,40 @@ class DeepObjectDetectionNode : public rclcpp_lifecycle::LifecycleNode
114112
const rclcpp_lifecycle::State &) override;
115113

116114
private:
117-
/**
118-
* @brief Declare and read all ROS2 parameters
119-
*
120-
* Reads model configuration, preprocessing/postprocessing parameters, camera topics,
121-
* batch settings, and backend provider settings from ROS2 parameters.
122-
*/
123115
void declareAndReadParameters();
124-
125-
/**
126-
* @brief Setup subscriptions to individual camera compressed image topics
127-
*
128-
* Creates one subscription per camera topic in params_.camera_topics.
129-
* Each subscription calls handleCompressedImage() with the camera index.
130-
*/
131-
void setupMultiCameraSubscriptions();
132-
133-
/**
134-
* @brief Setup subscription to synchronized MultiImage topic
135-
*
136-
* Creates a single subscription to camera_sync_topic_ that receives MultiImage
137-
* messages containing synchronized compressed images from multiple cameras.
138-
*/
139-
void setupCameraSyncSubscription();
140-
141-
/**
142-
* @brief Handle incoming MultiImage message with synchronized images
143-
* @param msg MultiImage message containing multiple compressed images
144-
*
145-
* Extracts each compressed image from the MultiImage and processes them
146-
* through handleCompressedImage() with sequential camera IDs.
147-
*/
116+
void setupSubscription();
148117
void onMultiImage(const deep_msgs::msg::MultiImage::ConstSharedPtr & msg);
149-
150-
/**
151-
* @brief Handle incoming compressed image from a camera
152-
* @param msg Compressed image message
153-
* @param camera_id Camera identifier (index for multi-camera, or from MultiImage)
154-
*
155-
* Decodes the compressed image, enqueues it for batch processing.
156-
* Handles decode failures according to decode_failure_policy.
157-
*/
158-
void handleCompressedImage(const sensor_msgs::msg::CompressedImage & msg, int camera_id);
159-
160-
/**
161-
* @brief Add image to processing queue
162-
* @param image Decoded BGR image (OpenCV Mat)
163-
* @param header ROS message header with timestamp and frame_id
164-
*
165-
* Thread-safe enqueueing. Applies queue_overflow_policy if queue is full.
166-
* Tracks first image timestamp for batch timeout calculation.
167-
*/
118+
void handleCompressedImage(const sensor_msgs::msg::CompressedImage & msg);
168119
void enqueueImage(cv::Mat image, const std_msgs::msg::Header & header);
169-
170-
/**
171-
* @brief Format tensor shape vector as string for logging
172-
* @param shape Vector of dimension sizes
173-
* @return Comma-separated string representation (e.g., "1, 3, 640, 640")
174-
*/
175-
std::string formatShape(const std::vector<size_t> & shape) const;
176-
177-
/**
178-
* @brief Timer callback for batch processing
179-
*
180-
* Called periodically (every 5ms) to check if batch should be processed.
181-
* Processes batch if:
182-
* - Queue size >= min_batch_size, OR
183-
* - max_batch_latency_ms exceeded and queue not empty
184-
* Extracts up to max_batch_size images and calls processBatch().
185-
*/
186120
void onBatchTimer();
187-
188-
/**
189-
* @brief Process a batch of images through inference pipeline
190-
* @param batch Vector of queued images to process
191-
*
192-
* For each image: preprocess -> inference -> postprocess -> publish detections.
193-
* Handles multi-output models if configured. Publishes Detection2DArray messages.
194-
*/
195121
void processBatch(const std::vector<QueuedImage> & batch);
196-
197-
/**
198-
* @brief Publish detection results for a batch
199-
* @param batch_detections Detections for each image in batch
200-
* @param headers Message headers for each image (for frame_id and timestamp)
201-
* @param metas Image metadata for coordinate transformation
202-
*
203-
* Creates and publishes Detection2DArray message for each image with its detections.
204-
*/
205122
void publishDetections(
206123
const std::vector<std::vector<SimpleDetection>> & batch_detections,
207124
const std::vector<std_msgs::msg::Header> & headers,
208125
const std::vector<ImageMeta> & metas);
209-
210-
/**
211-
* @brief Load class names from file
212-
*
213-
* Reads class names from params_.model_metadata.class_names_file (one per line).
214-
* Stores in params_.class_names for use in postprocessing and message publishing.
215-
*/
216126
void loadClassNames();
217-
218-
/**
219-
* @brief Stop all subscriptions and cancel batch timer
220-
*
221-
* Clears all camera subscriptions, resets MultiImage subscription,
222-
* cancels batch timer, and clears image queue. Used in deactivate/cleanup/shutdown.
223-
*/
127+
void cleanupPartialConfiguration();
128+
void cleanupAllResources();
224129
void stopSubscriptionsAndTimer();
225130

226-
DetectionParams params_; ///< All node configuration parameters
131+
DetectionParams params_;
132+
133+
rclcpp::Subscription<deep_msgs::msg::MultiImage>::SharedPtr multi_image_sub_;
134+
std::string input_topic_;
135+
rclcpp_lifecycle::LifecyclePublisher<Detection2DArrayMsg>::SharedPtr detection_pub_;
136+
rclcpp::TimerBase::SharedPtr batch_timer_;
227137

228-
std::vector<rclcpp::Subscription<sensor_msgs::msg::CompressedImage>::SharedPtr>
229-
multi_camera_subscriptions_; ///< Subscriptions for individual camera topics
230-
rclcpp::Subscription<deep_msgs::msg::MultiImage>::SharedPtr
231-
multi_image_sub_; ///< Subscription for synchronized MultiImage messages
232-
bool use_camera_sync_{false}; ///< Whether to use MultiImage sync mode or individual topics
233-
std::string camera_sync_topic_; ///< Topic name for MultiImage messages
234-
rclcpp::Publisher<Detection2DArrayMsg>::SharedPtr detection_pub_; ///< Publisher for detection results
235-
rclcpp::TimerBase::SharedPtr batch_timer_; ///< Timer for periodic batch processing checks
138+
std::deque<QueuedImage> image_queue_;
139+
std::mutex queue_mutex_;
140+
rclcpp::CallbackGroup::SharedPtr callback_group_;
236141

237-
std::deque<QueuedImage> image_queue_; ///< Queue of images waiting for batch processing
238-
std::mutex queue_mutex_; ///< Mutex protecting image_queue_ and first_image_timestamp_
239-
std::atomic<bool> processing_{false}; ///< Flag to prevent concurrent batch processing
240-
rclcpp::Time first_image_timestamp_; ///< Timestamp of oldest image in queue (for batch timeout)
142+
size_t dropped_images_count_{0};  ///< Dropped-image counter (per its name); brace-initialized so it never holds an indeterminate value
241143

242-
std::unique_ptr<ImagePreprocessor> preprocessor_; ///< Image preprocessing (resize, normalize, etc.)
243-
std::unique_ptr<GenericPostprocessor> postprocessor_; ///< Detection postprocessing (NMS, decode, etc.)
244-
std::unique_ptr<BackendManager> backend_manager_; ///< Backend plugin manager (CPU/CUDA/TensorRT)
144+
std::unique_ptr<ImagePreprocessor> preprocessor_;
145+
std::unique_ptr<GenericPostprocessor> postprocessor_;
146+
std::unique_ptr<BackendManager> backend_manager_;
245147
};
246148

247-
/**
248-
* @brief Factory function to create DeepObjectDetectionNode instance
249-
* @param options Node options for ROS2 configuration
250-
* @return Shared pointer to lifecycle node
251-
*
252-
* Used by rclcpp_components for component loading.
253-
*/
254149
std::shared_ptr<rclcpp_lifecycle::LifecycleNode> createDeepObjectDetectionNode(
255150
const rclcpp::NodeOptions & options = rclcpp::NodeOptions());
256151

deep_object_detection/include/deep_object_detection/detection_types.hpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ using Detection2DArrayMsg = vision_msgs::msg::Detection2DArray;
4444
namespace deep_object_detection
4545
{
4646

47-
// Image processing constants
48-
constexpr size_t RGB_CHANNELS = 3; // Number of channels in RGB/BGR images
47+
constexpr size_t RGB_CHANNELS = 3;
4948

5049
enum class Provider
5150
{
@@ -126,8 +125,8 @@ struct ImageMeta
126125

127126
struct QueuedImage
128127
{
129-
cv::Mat bgr;
130-
std_msgs::msg::Header header;
128+
cv::Mat bgr; ///< Decoded BGR image (OpenCV Mat)
129+
std_msgs::msg::Header header; ///< ROS message header with timestamp and frame_id
131130
};
132131

133132
struct PackedInput
@@ -191,12 +190,9 @@ struct DetectionParams
191190
ModelMetadata model_metadata;
192191
PreprocessingConfig preprocessing;
193192
PostprocessingConfig postprocessing;
194-
std::vector<std::string> camera_topics;
195193
std::string input_qos_reliability{"best_effort"};
196194
std::string output_detections_topic{"/detections"};
197-
int min_batch_size{1};
198195
int max_batch_size{3};
199-
int max_batch_latency_ms{0}; // 0 means no timeout (wait for min_batch_size)
200196
int queue_size{10};
201197
QueueOverflowPolicy queue_overflow_policy{QueueOverflowPolicy::DROP_OLDEST};
202198
DecodeFailurePolicy decode_failure_policy{DecodeFailurePolicy::DROP};

deep_object_detection/include/deep_object_detection/generic_postprocessor.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,17 @@ class GenericPostprocessor
5454

5555
static OutputLayout detectLayout(const std::vector<size_t> & output_shape);
5656

57+
/**
58+
* @brief Auto-configure output layout based on config and optional output shape
59+
* @param output_shape Model output shape (can be empty for deferred detection)
60+
* @param layout_config Layout configuration from parameters
61+
* @return Configured OutputLayout
62+
*
63+
* Handles both manual and auto-detection modes. If auto_detect is true and output_shape
64+
* is available, automatically detects layout. Otherwise uses manual config or defers detection.
65+
*/
66+
static OutputLayout autoConfigure(const std::vector<size_t> & output_shape, const OutputLayoutConfig & layout_config);
67+
5768
std::vector<std::vector<SimpleDetection>> decode(
5869
const deep_ros::Tensor & output, const std::vector<ImageMeta> & metas) const;
5970

deep_object_detection/src/backend_manager.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ Provider BackendManager::parseProvider(const std::string & provider_str) const
9393

9494
void BackendManager::initializeBackend()
9595
{
96-
// Check CUDA availability if needed
9796
if ((provider_ == Provider::TENSORRT || provider_ == Provider::CUDA) && !isCudaRuntimeAvailable()) {
9897
std::string error = "Provider " + providerToString(provider_) +
9998
" requires CUDA runtime libraries (libcudart/libcuda) which are not available";
@@ -105,11 +104,32 @@ void BackendManager::initializeBackend()
105104
throw std::runtime_error("No plugin name for provider: " + providerToString(provider_));
106105
}
107106

108-
// Update Backend.execution_provider parameter to match the actual provider
109107
const auto provider_name = providerToString(provider_);
110-
node_.set_parameters({rclcpp::Parameter("Backend.execution_provider", provider_name)});
111108

112-
// Pass the main node directly to the plugin (plugin will read Backend.* parameters from it)
109+
if (!node_.has_parameter("Backend.execution_provider")) {
110+
node_.declare_parameter<std::string>("Backend.execution_provider", provider_name);
111+
} else {
112+
node_.set_parameters({rclcpp::Parameter("Backend.execution_provider", provider_name)});
113+
}
114+
115+
if (!node_.has_parameter("Backend.device_id")) {
116+
node_.declare_parameter<int>("Backend.device_id", params_.device_id);
117+
} else {
118+
node_.set_parameters({rclcpp::Parameter("Backend.device_id", params_.device_id)});
119+
}
120+
121+
if (!node_.has_parameter("Backend.trt_engine_cache_enable")) {
122+
node_.declare_parameter<bool>("Backend.trt_engine_cache_enable", params_.enable_trt_engine_cache);
123+
} else {
124+
node_.set_parameters({rclcpp::Parameter("Backend.trt_engine_cache_enable", params_.enable_trt_engine_cache)});
125+
}
126+
127+
if (!node_.has_parameter("Backend.trt_engine_cache_path")) {
128+
node_.declare_parameter<std::string>("Backend.trt_engine_cache_path", params_.trt_engine_cache_path);
129+
} else {
130+
node_.set_parameters({rclcpp::Parameter("Backend.trt_engine_cache_path", params_.trt_engine_cache_path)});
131+
}
132+
113133
auto node_ptr = node_.shared_from_this();
114134
plugin_holder_ = plugin_loader_->createUniqueInstance(plugin_name);
115135
plugin_holder_->initialize(node_ptr);
@@ -195,7 +215,6 @@ std::string BackendManager::providerToString(Provider provider) const
195215
}
196216
}
197217

198-
199218
void BackendManager::declareActiveProviderParameter(const std::string & value)
200219
{
201220
rcl_interfaces::msg::ParameterDescriptor desc;

0 commit comments

Comments
 (0)