Skip to content

Commit 9399e23

Browse files
committed
adding dynamic model path reconfiguration
1 parent 52d43d7 commit 9399e23

File tree

9 files changed

+433
-22
lines changed

9 files changed

+433
-22
lines changed

deep_core/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ if(BUILD_TESTING)
7575
${DEEP_CORE_LIB}
7676
)
7777

78+
add_deep_test(test_dynamic_reconfiguration test/test_dynamic_reconfiguration.cpp
79+
LIBRARIES
80+
${DEEP_CORE_LIB}
81+
)
82+
7883
endif()
7984

8085
ament_export_targets(${PROJECT_NAME}Targets HAS_LIBRARY_TARGET)

deep_core/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ All nodes inherenting `deep_ros::DeepNodeBase` have the following settable param
3838

3939
Required parameters:
4040
- `Backend.plugin`: Plugin name (e.g., "onnxruntime_cpu")
41-
- `model_path`: Path to model file
41+
- `model_path`: Path to model file (dynamically reconfigurable on runtime,
42+
you can switch model's while the node is running!)
4243

4344
Optional parameters:
4445
- `Bond.enable`: Enable bond connections (default: false)

deep_core/include/deep_core/deep_node_base.hpp

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <vector>
2121

2222
#include <bondcpp/bond.hpp>
23+
#include <lifecycle_msgs/msg/state.hpp>
2324
#include <pluginlib/class_list_macros.hpp>
2425
#include <pluginlib/class_loader.hpp>
2526
#include <rclcpp_lifecycle/lifecycle_node.hpp>
@@ -41,19 +42,12 @@ using CallbackReturn = rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface
4142
* DeepNodeBase provides a lifecycle-managed ROS 2 node that can load and manage
4243
* deep learning backend plugins. It handles the plugin discovery, loading, model
4344
* management, and provides a simple interface for running inference.
44-
*
45-
* Users should inherit from this class and implement the lifecycle callback methods
46-
* to create custom inference nodes.
47-
*
48-
* Configuration parameters:
49-
* - Backend.plugin: Backend plugin name (e.g., "deep_ort_backend_plugin::OrtBackendPlugin")
50-
* - model_path: Path to the model file
5145
*/
5246
class DeepNodeBase : public rclcpp_lifecycle::LifecycleNode
5347
{
5448
public:
5549
/**
56-
* @brief Construct a new DeepNodeBase
50+
* @brief Construct a new DeepNodeBase - initialize base parameters
5751
* @param node_name Name of the ROS 2 node
5852
* @param options ROS 2 node options
5953
*/
@@ -172,15 +166,53 @@ class DeepNodeBase : public rclcpp_lifecycle::LifecycleNode
172166
std::shared_ptr<BackendMemoryAllocator> get_current_allocator() const;
173167

174168
private:
175-
// Final lifecycle callbacks - base handles backend, then calls user impl
169+
/**
170+
* @brief Configure lifecycle callback - retrieve parameter values,
171+
* loads plugin and model, then calls user implementation
172+
* @param state Current lifecycle state
173+
* @return Callback return status
174+
*/
176175
CallbackReturn on_configure(const rclcpp_lifecycle::State & state) final;
176+
177+
/**
178+
* @brief Activate lifecycle callback - starts bond if enabled, then calls user implementation
179+
* @param state Current lifecycle state
180+
* @return Callback return status
181+
*/
177182
CallbackReturn on_activate(const rclcpp_lifecycle::State & state) final;
183+
184+
/**
185+
* @brief Deactivate lifecycle callback - calls user implementation
186+
* @param state Current lifecycle state
187+
* @return Callback return status
188+
*/
178189
CallbackReturn on_deactivate(const rclcpp_lifecycle::State & state) final;
190+
191+
/**
192+
* @brief Cleanup lifecycle callback - unloads model/plugin, stops bond, then calls user implementation
193+
* @param state Current lifecycle state
194+
* @return Callback return status
195+
*/
179196
CallbackReturn on_cleanup(const rclcpp_lifecycle::State & state) final;
197+
198+
/**
199+
* @brief Shutdown lifecycle callback - unloads model/plugin, stops bond, then calls user implementation
200+
* @param state Current lifecycle state
201+
* @return Callback return status
202+
*/
180203
CallbackReturn on_shutdown(const rclcpp_lifecycle::State & state) final;
181204

182-
// Plugin discovery and loading
205+
/**
206+
* @brief Discover available backend plugins using pluginlib
207+
* @return Vector of plugin class names
208+
*/
183209
std::vector<std::string> discover_available_plugins();
210+
211+
/**
212+
* @brief Load a specific backend plugin library
213+
* @param plugin_name Name of the plugin class to load
214+
* @return Unique pointer to the loaded plugin instance
215+
*/
184216
pluginlib::UniquePtr<DeepBackendPlugin> load_plugin_library(const std::string & plugin_name);
185217

186218
// Plugin loader
@@ -201,6 +233,10 @@ class DeepNodeBase : public rclcpp_lifecycle::LifecycleNode
201233
// ROS parameters
202234
void declare_parameters();
203235
void setup_bond();
236+
rcl_interfaces::msg::SetParametersResult on_parameter_change(const std::vector<rclcpp::Parameter> & parameters);
237+
238+
// Parameter callback for dynamic reconfiguration
239+
rclcpp::node_interfaces::OnSetParametersCallbackHandle::SharedPtr parameter_callback_handle_;
204240
};
205241

206242
} // namespace deep_ros

deep_core/src/deep_node_base.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ DeepNodeBase::DeepNodeBase(const std::string & node_name, const rclcpp::NodeOpti
2929
plugin_loader_ =
3030
std::make_unique<pluginlib::ClassLoader<DeepBackendPlugin>>("deep_core", "deep_ros::DeepBackendPlugin");
3131
declare_parameters();
32+
33+
// Set up parameter callback for dynamic reconfiguration
34+
parameter_callback_handle_ =
35+
add_on_set_parameters_callback(std::bind(&DeepNodeBase::on_parameter_change, this, std::placeholders::_1));
3236
}
3337

3438
void DeepNodeBase::declare_parameters()
@@ -240,4 +244,54 @@ void DeepNodeBase::setup_bond()
240244
}
241245
}
242246

247+
rcl_interfaces::msg::SetParametersResult DeepNodeBase::on_parameter_change(
248+
const std::vector<rclcpp::Parameter> & parameters)
249+
{
250+
rcl_interfaces::msg::SetParametersResult result;
251+
result.successful = true;
252+
253+
for (const auto & param : parameters) {
254+
if (param.get_name() == "model_path") {
255+
// Only allow model changes when node is active for safety
256+
if (get_current_state().id() == lifecycle_msgs::msg::State::PRIMARY_STATE_ACTIVE) {
257+
std::string new_model_path = param.as_string();
258+
259+
// Reject empty model paths entirely
260+
if (new_model_path.empty()) {
261+
RCLCPP_ERROR(get_logger(), "Cannot set empty model path");
262+
result.successful = false;
263+
result.reason = "Cannot set empty model path";
264+
} else if (new_model_path != current_model_path_.string()) {
265+
RCLCPP_INFO(
266+
get_logger(),
267+
"Dynamically changing model from '%s' to '%s'",
268+
current_model_path_.c_str(),
269+
new_model_path.c_str());
270+
271+
// Unload current model
272+
unload_model();
273+
274+
// Load new model
275+
if (!load_model(new_model_path)) {
276+
RCLCPP_ERROR(get_logger(), "Failed to load new model: %s", new_model_path.c_str());
277+
result.successful = false;
278+
result.reason = "Failed to load new model: " + new_model_path;
279+
} else {
280+
RCLCPP_INFO(get_logger(), "Successfully loaded new model: %s", new_model_path.c_str());
281+
}
282+
}
283+
} else {
284+
RCLCPP_WARN(
285+
get_logger(),
286+
"Cannot change model_path when node is not active. Current state: %s",
287+
get_current_state().label().c_str());
288+
result.successful = false;
289+
result.reason = "Node must be active to change model_path";
290+
}
291+
}
292+
}
293+
294+
return result;
295+
}
296+
243297
} // namespace deep_ros

deep_core/src/tensor.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ TensorPtr::TensorPtr(
8282
, is_view_(false)
8383
, allocator_(allocator)
8484
{
85+
if (shape_.empty()) {
86+
throw std::invalid_argument("Tensor shape cannot be empty");
87+
}
88+
8589
calculate_strides();
8690

8791
size_t total_elements = std::accumulate(shape_.begin(), shape_.end(), 1UL, std::multiplies<size_t>());
@@ -97,6 +101,10 @@ TensorPtr::TensorPtr(void * data, const std::vector<size_t> & shape, DataType dt
97101
, is_view_(false)
98102
, allocator_(nullptr)
99103
{
104+
if (shape_.empty()) {
105+
throw std::invalid_argument("Tensor shape cannot be empty");
106+
}
107+
100108
calculate_strides();
101109

102110
size_t total_elements = std::accumulate(shape_.begin(), shape_.end(), 1UL, std::multiplies<size_t>());

deep_core/test/test_deep_core.cpp

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <catch2/catch.hpp>
2121
#include <deep_core/deep_node_base.hpp>
2222
#include <deep_core/types/tensor.hpp>
23+
#include <deep_test/deep_test.hpp>
2324

2425
namespace deep_ros
2526
{
@@ -114,15 +115,12 @@ TEST_CASE("Different data types have correct sizes", "[tensor]")
114115
REQUIRE(uint8_tensor.size() == 10);
115116
}
116117

117-
TEST_CASE("Empty shape creates scalar tensor", "[tensor]")
118+
TEST_CASE("Empty shape throws exception", "[tensor]")
118119
{
119120
auto allocator = std::make_shared<MockMemoryAllocator>();
120121
std::vector<size_t> empty_shape;
121122

122-
TensorPtr tensor(empty_shape, DataType::FLOAT32, allocator);
123-
124-
REQUIRE(tensor.size() == 1); // Scalar tensor
125-
REQUIRE(tensor.shape().empty());
123+
REQUIRE_THROWS_AS(TensorPtr(empty_shape, DataType::FLOAT32, allocator), std::invalid_argument);
126124
}
127125

128126
TEST_CASE("Large shape allocation", "[tensor]")
@@ -263,6 +261,15 @@ class TestInferenceNode : public DeepNodeBase
263261
: DeepNodeBase("test_inference_node", options)
264262
{}
265263

264+
// Expose protected methods for testing
265+
using DeepNodeBase::get_backend_name;
266+
using DeepNodeBase::get_current_allocator;
267+
using DeepNodeBase::is_model_loaded;
268+
using DeepNodeBase::is_plugin_loaded;
269+
using DeepNodeBase::load_model;
270+
using DeepNodeBase::load_plugin;
271+
using DeepNodeBase::run_inference;
272+
266273
bool test_load_plugin(const std::string & plugin_name)
267274
{
268275
return load_plugin(plugin_name);
@@ -305,10 +312,73 @@ class TestInferenceNode : public DeepNodeBase
305312
}
306313
};
307314

308-
TEST_CASE("DeepNodeBase creation", "[node]")
315+
TEST_CASE_METHOD(deep_ros::test::TestExecutorFixture, "DeepNodeBase lifecycle management", "[node]")
309316
{
310-
// Skip ROS lifecycle node tests to avoid segfault
311-
REQUIRE(true);
317+
// Create a test node that inherits from DeepNodeBase
318+
auto test_node = std::make_shared<TestInferenceNode>();
319+
add_node(test_node);
320+
start_spinning();
321+
322+
SECTION("Node creation and initial state")
323+
{
324+
REQUIRE(test_node->get_name() == std::string("test_inference_node"));
325+
REQUIRE(test_node->get_current_state().id() == lifecycle_msgs::msg::State::PRIMARY_STATE_UNCONFIGURED);
326+
REQUIRE(test_node->is_plugin_loaded() == false);
327+
REQUIRE(test_node->is_model_loaded() == false);
328+
}
329+
330+
SECTION("Lifecycle transitions work correctly")
331+
{
332+
// Configure
333+
auto configure_result = test_node->configure();
334+
REQUIRE(configure_result.id() == lifecycle_msgs::msg::State::PRIMARY_STATE_INACTIVE);
335+
336+
// Activate
337+
auto activate_result = test_node->activate();
338+
REQUIRE(activate_result.id() == lifecycle_msgs::msg::State::PRIMARY_STATE_ACTIVE);
339+
340+
// Deactivate
341+
auto deactivate_result = test_node->deactivate();
342+
REQUIRE(deactivate_result.id() == lifecycle_msgs::msg::State::PRIMARY_STATE_INACTIVE);
343+
344+
// Cleanup
345+
auto cleanup_result = test_node->cleanup();
346+
REQUIRE(cleanup_result.id() == lifecycle_msgs::msg::State::PRIMARY_STATE_UNCONFIGURED);
347+
}
348+
349+
SECTION("Plugin loading functionality")
350+
{
351+
// Configure first
352+
test_node->configure();
353+
354+
// Try to load a non-existent plugin
355+
bool load_result = test_node->test_load_plugin("nonexistent_plugin");
356+
REQUIRE(load_result == false);
357+
REQUIRE(test_node->is_plugin_loaded() == false);
358+
REQUIRE(test_node->get_backend_name() == "none");
359+
REQUIRE(test_node->get_current_allocator() == nullptr);
360+
}
361+
362+
SECTION("Model loading without plugin fails")
363+
{
364+
test_node->configure();
365+
test_node->activate();
366+
367+
// Try to load model without plugin
368+
bool model_result = test_node->test_load_model("/fake/model.onnx");
369+
REQUIRE(model_result == false);
370+
REQUIRE(test_node->is_model_loaded() == false);
371+
}
372+
373+
SECTION("Backend functionality with no plugin")
374+
{
375+
test_node->configure();
376+
test_node->activate();
377+
378+
// Verify backend state when no plugin is loaded
379+
REQUIRE(test_node->get_backend_name() == "none");
380+
REQUIRE(test_node->get_current_allocator() == nullptr);
381+
}
312382
}
313383

314384
} // namespace test

0 commit comments

Comments
 (0)