|
20 | 20 | #include <catch2/catch.hpp> |
21 | 21 | #include <deep_core/deep_node_base.hpp> |
22 | 22 | #include <deep_core/types/tensor.hpp> |
| 23 | +#include <deep_test/deep_test.hpp> |
23 | 24 |
|
24 | 25 | namespace deep_ros |
25 | 26 | { |
@@ -114,15 +115,12 @@ TEST_CASE("Different data types have correct sizes", "[tensor]") |
114 | 115 | REQUIRE(uint8_tensor.size() == 10); |
115 | 116 | } |
116 | 117 |
|
117 | | -TEST_CASE("Empty shape creates scalar tensor", "[tensor]") |
| 118 | +TEST_CASE("Empty shape throws exception", "[tensor]") |
118 | 119 | { |
119 | 120 | auto allocator = std::make_shared<MockMemoryAllocator>(); |
120 | 121 | std::vector<size_t> empty_shape; |
121 | 122 |
|
122 | | - TensorPtr tensor(empty_shape, DataType::FLOAT32, allocator); |
123 | | - |
124 | | - REQUIRE(tensor.size() == 1); // Scalar tensor |
125 | | - REQUIRE(tensor.shape().empty()); |
| 123 | + REQUIRE_THROWS_AS(TensorPtr(empty_shape, DataType::FLOAT32, allocator), std::invalid_argument); |
126 | 124 | } |
127 | 125 |
|
128 | 126 | TEST_CASE("Large shape allocation", "[tensor]") |
@@ -263,6 +261,15 @@ class TestInferenceNode : public DeepNodeBase |
263 | 261 | : DeepNodeBase("test_inference_node", options) |
264 | 262 | {} |
265 | 263 |
|
| 264 | + // Expose protected methods for testing |
| 265 | + using DeepNodeBase::get_backend_name; |
| 266 | + using DeepNodeBase::get_current_allocator; |
| 267 | + using DeepNodeBase::is_model_loaded; |
| 268 | + using DeepNodeBase::is_plugin_loaded; |
| 269 | + using DeepNodeBase::load_model; |
| 270 | + using DeepNodeBase::load_plugin; |
| 271 | + using DeepNodeBase::run_inference; |
| 272 | + |
266 | 273 | bool test_load_plugin(const std::string & plugin_name) |
267 | 274 | { |
268 | 275 | return load_plugin(plugin_name); |
@@ -305,10 +312,73 @@ class TestInferenceNode : public DeepNodeBase |
305 | 312 | } |
306 | 313 | }; |
307 | 314 |
|
308 | | -TEST_CASE("DeepNodeBase creation", "[node]") |
| 315 | +TEST_CASE_METHOD(deep_ros::test::TestExecutorFixture, "DeepNodeBase lifecycle management", "[node]") |
309 | 316 | { |
310 | | - // Skip ROS lifecycle node tests to avoid segfault |
311 | | - REQUIRE(true); |
| 317 | + // Create a test node that inherits from DeepNodeBase |
| 318 | + auto test_node = std::make_shared<TestInferenceNode>(); |
| 319 | + add_node(test_node); |
| 320 | + start_spinning(); |
| 321 | + |
| 322 | + SECTION("Node creation and initial state") |
| 323 | + { |
| 324 | + REQUIRE(test_node->get_name() == std::string("test_inference_node")); |
| 325 | + REQUIRE(test_node->get_current_state().id() == lifecycle_msgs::msg::State::PRIMARY_STATE_UNCONFIGURED); |
| 326 | + REQUIRE(test_node->is_plugin_loaded() == false); |
| 327 | + REQUIRE(test_node->is_model_loaded() == false); |
| 328 | + } |
| 329 | + |
| 330 | + SECTION("Lifecycle transitions work correctly") |
| 331 | + { |
| 332 | + // Configure |
| 333 | + auto configure_result = test_node->configure(); |
| 334 | + REQUIRE(configure_result.id() == lifecycle_msgs::msg::State::PRIMARY_STATE_INACTIVE); |
| 335 | + |
| 336 | + // Activate |
| 337 | + auto activate_result = test_node->activate(); |
| 338 | + REQUIRE(activate_result.id() == lifecycle_msgs::msg::State::PRIMARY_STATE_ACTIVE); |
| 339 | + |
| 340 | + // Deactivate |
| 341 | + auto deactivate_result = test_node->deactivate(); |
| 342 | + REQUIRE(deactivate_result.id() == lifecycle_msgs::msg::State::PRIMARY_STATE_INACTIVE); |
| 343 | + |
| 344 | + // Cleanup |
| 345 | + auto cleanup_result = test_node->cleanup(); |
| 346 | + REQUIRE(cleanup_result.id() == lifecycle_msgs::msg::State::PRIMARY_STATE_UNCONFIGURED); |
| 347 | + } |
| 348 | + |
| 349 | + SECTION("Plugin loading functionality") |
| 350 | + { |
| 351 | + // Configure first |
| 352 | + test_node->configure(); |
| 353 | + |
| 354 | + // Try to load a non-existent plugin |
| 355 | + bool load_result = test_node->test_load_plugin("nonexistent_plugin"); |
| 356 | + REQUIRE(load_result == false); |
| 357 | + REQUIRE(test_node->is_plugin_loaded() == false); |
| 358 | + REQUIRE(test_node->get_backend_name() == "none"); |
| 359 | + REQUIRE(test_node->get_current_allocator() == nullptr); |
| 360 | + } |
| 361 | + |
| 362 | + SECTION("Model loading without plugin fails") |
| 363 | + { |
| 364 | + test_node->configure(); |
| 365 | + test_node->activate(); |
| 366 | + |
| 367 | + // Try to load model without plugin |
| 368 | + bool model_result = test_node->test_load_model("/fake/model.onnx"); |
| 369 | + REQUIRE(model_result == false); |
| 370 | + REQUIRE(test_node->is_model_loaded() == false); |
| 371 | + } |
| 372 | + |
| 373 | + SECTION("Backend functionality with no plugin") |
| 374 | + { |
| 375 | + test_node->configure(); |
| 376 | + test_node->activate(); |
| 377 | + |
| 378 | + // Verify backend state when no plugin is loaded |
| 379 | + REQUIRE(test_node->get_backend_name() == "none"); |
| 380 | + REQUIRE(test_node->get_current_allocator() == nullptr); |
| 381 | + } |
312 | 382 | } |
313 | 383 |
|
314 | 384 | } // namespace test |
|
0 commit comments