5151
5252namespace autoware ::diffusion_planner
5353{
54+ using diagnostic_msgs::msg::DiagnosticStatus;
55+
5456DiffusionPlanner::DiffusionPlanner (const rclcpp::NodeOptions & options)
5557: Node(" diffusion_planner" , options), generator_uuid_(autoware_utils_uuid::generate_uuid())
5658{
@@ -71,22 +73,23 @@ DiffusionPlanner::DiffusionPlanner(const rclcpp::NodeOptions & options)
7173 turn_indicator_manager_.set_hold_duration (
7274 rclcpp::Duration::from_seconds (params_.turn_indicator_hold_duration ));
7375 turn_indicator_manager_.set_keep_offset (params_.turn_indicator_keep_offset );
74- utils::check_weight_version (params_.args_path );
75- normalization_map_ = utils::load_normalization_stats (params_.args_path );
7676
7777 diagnostics_inference_ = std::make_unique<DiagnosticsInterface>(this , " inference_status" );
78- diagnostics_inference_->update_level_and_message (
79- diagnostic_msgs::msg::DiagnosticStatus::WARN, " Loading model weights" );
80- diagnostics_inference_->publish (get_clock ()->now ());
81- tensorrt_inference_ = std::make_unique<TensorrtInference>(
82- params_.model_path , params_.plugins_path , params_.batch_size );
83- diagnostics_inference_->update_level_and_message (
84- diagnostic_msgs::msg::DiagnosticStatus::OK, " Model weights loaded" );
85- diagnostics_inference_->publish (get_clock ()->now ());
86-
87- if (params_.build_only ) {
88- RCLCPP_INFO (get_logger (), " Build only mode enabled. Exiting after loading model." );
89- std::exit (EXIT_SUCCESS);
78+ try {
79+ load_model ();
80+ if (params_.build_only ) {
81+ RCLCPP_INFO (get_logger (), " Build only mode enabled. Exiting after loading model." );
82+ std::exit (EXIT_SUCCESS);
83+ }
84+ } catch (const std::exception & e) {
85+ RCLCPP_ERROR_STREAM (get_logger (), e.what () << " . Inference will be disabled." );
86+ tensorrt_inference_.reset ();
87+ diagnostics_inference_->update_level_and_message (DiagnosticStatus::ERROR, e.what ());
88+ diagnostics_inference_->publish (get_clock ()->now ());
89+ if (params_.build_only ) {
90+ RCLCPP_ERROR (get_logger (), " Build only mode: exiting due to model load failure." );
91+ std::exit (EXIT_FAILURE);
92+ }
9093 }
9194
9295 vehicle_info_ = autoware::vehicle_info_utils::VehicleInfoUtils (*this ).getVehicleInfo ();
@@ -141,15 +144,29 @@ void DiffusionPlanner::set_up_params()
141144 this ->declare_parameter <bool >(" debug_params.publish_debug_route" , true );
142145}
143146
// Loads everything needed for inference from the paths currently stored in params_:
// runs the weight-version check, refreshes normalization_map_, and (re)creates
// tensorrt_inference_, publishing the progress through diagnostics_inference_.
//
// No error is handled in this function. Both call sites visible in this file (the
// constructor and on_parameter) wrap the call in a try/catch for std::exception and
// reset tensorrt_inference_ when it fails.
// NOTE(review): presumably check_weight_version, load_normalization_stats and the
// TensorrtInference constructor report failure by throwing -- confirm against their
// definitions.
147+ void DiffusionPlanner::load_model ()
148+ {
// Both helpers read from params_.args_path; the version check runs first.
149+ utils::check_weight_version (params_.args_path );
// normalization_map_ is overwritten here, before the engine is created, so it keeps
// the new statistics even if a later step in this function throws.
150+ normalization_map_ = utils::load_normalization_stats (params_.args_path );
// Publish a WARN status before constructing the engine. If construction throws,
// this WARN is the last status this function published.
151+ diagnostics_inference_->update_level_and_message (DiagnosticStatus::WARN, " Loading model" );
152+ diagnostics_inference_->publish (get_clock ()->now ());
// The new engine is fully constructed before it is assigned to tensorrt_inference_,
// so the previous engine (if any) is still in place if construction throws.
153+ tensorrt_inference_ = std::make_unique<TensorrtInference>(
154+ params_.model_path , params_.plugins_path , params_.batch_size );
// Reached only when every step above completed without throwing.
155+ diagnostics_inference_->update_level_and_message (DiagnosticStatus::OK, " Model loaded" );
156+ diagnostics_inference_->publish (get_clock ()->now ());
157+ }
158+
144159SetParametersResult DiffusionPlanner::on_parameter (
145160 [[maybe_unused]] const std::vector<rclcpp::Parameter> & parameters)
146161{
147162 using autoware_utils::update_param;
148163 {
149164 DiffusionPlannerParams temp_params = params_;
165+ const auto previous_args_path = params_.args_path ;
150166 const auto previous_model_path = params_.model_path ;
151167 const auto previous_batch_size = params_.batch_size ;
152168 update_param<std::string>(parameters, " onnx_model_path" , temp_params.model_path );
169+ update_param<std::string>(parameters, " args_path" , temp_params.args_path );
153170 update_param<bool >(
154171 parameters, " ignore_unknown_neighbors" , temp_params.ignore_unknown_neighbors );
155172 update_param<bool >(parameters, " ignore_neighbors" , temp_params.ignore_neighbors );
@@ -168,22 +185,25 @@ SetParametersResult DiffusionPlanner::on_parameter(
168185 update_param<double >(
169186 parameters, " turn_indicator_hold_duration" , temp_params.turn_indicator_hold_duration );
170187 update_param<bool >(parameters, " shift_x" , temp_params.shift_x );
188+ const bool args_path_changed = temp_params.args_path != previous_args_path;
171189 const bool model_path_changed = temp_params.model_path != previous_model_path;
172190 const bool batch_size_changed = temp_params.batch_size != previous_batch_size;
173191 params_ = temp_params;
174192 turn_indicator_manager_.set_hold_duration (
175193 rclcpp::Duration::from_seconds (params_.turn_indicator_hold_duration ));
176194 turn_indicator_manager_.set_keep_offset (params_.turn_indicator_keep_offset );
177195
178- if ((model_path_changed || batch_size_changed) && tensorrt_inference_) {
179- diagnostics_inference_->update_level_and_message (
180- diagnostic_msgs::msg::DiagnosticStatus::WARN, " Loading model weights" );
181- diagnostics_inference_->publish (get_clock ()->now ());
182- tensorrt_inference_ = std::make_unique<TensorrtInference>(
183- params_.model_path , params_.plugins_path , params_.batch_size );
184- diagnostics_inference_->update_level_and_message (
185- diagnostic_msgs::msg::DiagnosticStatus::OK, " Model weights loaded" );
186- diagnostics_inference_->publish (get_clock ()->now ());
196+ if (args_path_changed || model_path_changed || batch_size_changed) {
197+ try {
198+ load_model ();
199+ } catch (const std::exception & e) {
200+ RCLCPP_ERROR_STREAM (get_logger (), e.what () << " . Failed to reload model." );
201+ tensorrt_inference_.reset ();
202+ SetParametersResult result;
203+ result.successful = false ;
204+ result.reason = e.what ();
205+ return result;
206+ }
187207 }
188208 }
189209
@@ -516,12 +536,20 @@ void DiffusionPlanner::on_timer()
516536 diagnostics_inference_->clear ();
517537
518538 const rclcpp::Time current_time (get_clock ()->now ());
539+ if (!tensorrt_inference_) {
540+ RCLCPP_WARN_THROTTLE (
541+ get_logger (), *this ->get_clock (), constants::LOG_THROTTLE_INTERVAL_MS,
542+ " Model not loaded. Inference is disabled. Check onnx_model_path and args_path parameters." );
543+ diagnostics_inference_->update_level_and_message (DiagnosticStatus::ERROR, " Model not loaded" );
544+ diagnostics_inference_->publish (current_time);
545+ return ;
546+ }
547+
519548 if (!lane_segment_context_) {
520549 RCLCPP_INFO_THROTTLE (
521550 get_logger (), *this ->get_clock (), constants::LOG_THROTTLE_INTERVAL_MS,
522551 " Waiting for map data..." );
523- diagnostics_inference_->update_level_and_message (
524- diagnostic_msgs::msg::DiagnosticStatus::WARN, " Map data not loaded" );
552+ diagnostics_inference_->update_level_and_message (DiagnosticStatus::WARN, " Map data not loaded" );
525553 diagnostics_inference_->publish (current_time);
526554 return ;
527555 }
@@ -533,7 +561,7 @@ void DiffusionPlanner::on_timer()
533561 get_logger (), *this ->get_clock (), constants::LOG_THROTTLE_INTERVAL_MS,
534562 " No input data available for inference" );
535563 diagnostics_inference_->update_level_and_message (
536- diagnostic_msgs::msg:: DiagnosticStatus::WARN, " No input data available for inference" );
564+ DiagnosticStatus::WARN, " No input data available for inference" );
537565 diagnostics_inference_->publish (current_time);
538566 return ;
539567 }
@@ -575,7 +603,7 @@ void DiffusionPlanner::on_timer()
575603 get_logger (), *this ->get_clock (), constants::LOG_THROTTLE_INTERVAL_MS,
576604 " Input data contains invalid values" );
577605 diagnostics_inference_->update_level_and_message (
578- diagnostic_msgs::msg:: DiagnosticStatus::WARN, " Input data contains invalid values" );
606+ DiagnosticStatus::WARN, " Input data contains invalid values" );
579607 diagnostics_inference_->publish (current_time);
580608 return ;
581609 }
@@ -587,7 +615,7 @@ void DiffusionPlanner::on_timer()
587615 get_logger (), *this ->get_clock (), constants::LOG_THROTTLE_INTERVAL_MS,
588616 " Inference failed: " << inference_result.error_msg );
589617 diagnostics_inference_->update_level_and_message (
590- diagnostic_msgs::msg:: DiagnosticStatus::ERROR, inference_result.error_msg );
618+ DiagnosticStatus::ERROR, inference_result.error_msg );
591619 diagnostics_inference_->publish (frame_time);
592620 return ;
593621 }
0 commit comments