Skip to content

Commit 2f8fc4f

Browse files
fix(diffusion_planner): modify loading model (#12125)
* Added `load_model`. Signed-off-by: Shintaro Sakoda <shintaro.sakoda@tier4.jp>
* Fixed to fail. Signed-off-by: Shintaro Sakoda <shintaro.sakoda@tier4.jp>
* Fixed. Signed-off-by: Shintaro Sakoda <shintaro.sakoda@tier4.jp>
* Removed `reject_` from the result variable. Signed-off-by: Shintaro Sakoda <shintaro.sakoda@tier4.jp>
* Update planning/autoware_diffusion_planner/include/autoware/diffusion_planner/diffusion_planner_node.hpp. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Fixed the timing of checking `build_only`. Signed-off-by: Shintaro Sakoda <shintaro.sakoda@tier4.jp>

---------

Signed-off-by: Shintaro Sakoda <shintaro.sakoda@tier4.jp>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 016b29f commit 2f8fc4f

File tree

2 files changed

+66
-28
lines changed

2 files changed

+66
-28
lines changed

planning/autoware_diffusion_planner/include/autoware/diffusion_planner/diffusion_planner_node.hpp

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -183,6 +183,16 @@ class DiffusionPlanner : public rclcpp::Node
183183
*/
184184
void set_up_params();
185185

186+
/**
187+
* @brief Load TensorRT model and normalization statistics.
188+
*
189+
* Updates the normalization_map_ and tensorrt_inference_ member variables.
190+
*
191+
* @throws std::runtime_error if args_path or model_path are invalid, if the
192+
* model version is incompatible, or if TensorRT engine setup fails.
193+
*/
194+
void load_model();
195+
186196
/**
187197
* @brief Timer callback for periodic processing and publishing.
188198
*/

planning/autoware_diffusion_planner/src/diffusion_planner_node.cpp

Lines changed: 56 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,8 @@
5151

5252
namespace autoware::diffusion_planner
5353
{
54+
using diagnostic_msgs::msg::DiagnosticStatus;
55+
5456
DiffusionPlanner::DiffusionPlanner(const rclcpp::NodeOptions & options)
5557
: Node("diffusion_planner", options), generator_uuid_(autoware_utils_uuid::generate_uuid())
5658
{
@@ -71,22 +73,23 @@ DiffusionPlanner::DiffusionPlanner(const rclcpp::NodeOptions & options)
7173
turn_indicator_manager_.set_hold_duration(
7274
rclcpp::Duration::from_seconds(params_.turn_indicator_hold_duration));
7375
turn_indicator_manager_.set_keep_offset(params_.turn_indicator_keep_offset);
74-
utils::check_weight_version(params_.args_path);
75-
normalization_map_ = utils::load_normalization_stats(params_.args_path);
7676

7777
diagnostics_inference_ = std::make_unique<DiagnosticsInterface>(this, "inference_status");
78-
diagnostics_inference_->update_level_and_message(
79-
diagnostic_msgs::msg::DiagnosticStatus::WARN, "Loading model weights");
80-
diagnostics_inference_->publish(get_clock()->now());
81-
tensorrt_inference_ = std::make_unique<TensorrtInference>(
82-
params_.model_path, params_.plugins_path, params_.batch_size);
83-
diagnostics_inference_->update_level_and_message(
84-
diagnostic_msgs::msg::DiagnosticStatus::OK, "Model weights loaded");
85-
diagnostics_inference_->publish(get_clock()->now());
86-
87-
if (params_.build_only) {
88-
RCLCPP_INFO(get_logger(), "Build only mode enabled. Exiting after loading model.");
89-
std::exit(EXIT_SUCCESS);
78+
try {
79+
load_model();
80+
if (params_.build_only) {
81+
RCLCPP_INFO(get_logger(), "Build only mode enabled. Exiting after loading model.");
82+
std::exit(EXIT_SUCCESS);
83+
}
84+
} catch (const std::exception & e) {
85+
RCLCPP_ERROR_STREAM(get_logger(), e.what() << ". Inference will be disabled.");
86+
tensorrt_inference_.reset();
87+
diagnostics_inference_->update_level_and_message(DiagnosticStatus::ERROR, e.what());
88+
diagnostics_inference_->publish(get_clock()->now());
89+
if (params_.build_only) {
90+
RCLCPP_ERROR(get_logger(), "Build only mode: exiting due to model load failure.");
91+
std::exit(EXIT_FAILURE);
92+
}
9093
}
9194

9295
vehicle_info_ = autoware::vehicle_info_utils::VehicleInfoUtils(*this).getVehicleInfo();
@@ -141,15 +144,29 @@ void DiffusionPlanner::set_up_params()
141144
this->declare_parameter<bool>("debug_params.publish_debug_route", true);
142145
}
143146

147+
void DiffusionPlanner::load_model()
148+
{
149+
utils::check_weight_version(params_.args_path);
150+
normalization_map_ = utils::load_normalization_stats(params_.args_path);
151+
diagnostics_inference_->update_level_and_message(DiagnosticStatus::WARN, "Loading model");
152+
diagnostics_inference_->publish(get_clock()->now());
153+
tensorrt_inference_ = std::make_unique<TensorrtInference>(
154+
params_.model_path, params_.plugins_path, params_.batch_size);
155+
diagnostics_inference_->update_level_and_message(DiagnosticStatus::OK, "Model loaded");
156+
diagnostics_inference_->publish(get_clock()->now());
157+
}
158+
144159
SetParametersResult DiffusionPlanner::on_parameter(
145160
[[maybe_unused]] const std::vector<rclcpp::Parameter> & parameters)
146161
{
147162
using autoware_utils::update_param;
148163
{
149164
DiffusionPlannerParams temp_params = params_;
165+
const auto previous_args_path = params_.args_path;
150166
const auto previous_model_path = params_.model_path;
151167
const auto previous_batch_size = params_.batch_size;
152168
update_param<std::string>(parameters, "onnx_model_path", temp_params.model_path);
169+
update_param<std::string>(parameters, "args_path", temp_params.args_path);
153170
update_param<bool>(
154171
parameters, "ignore_unknown_neighbors", temp_params.ignore_unknown_neighbors);
155172
update_param<bool>(parameters, "ignore_neighbors", temp_params.ignore_neighbors);
@@ -168,22 +185,25 @@ SetParametersResult DiffusionPlanner::on_parameter(
168185
update_param<double>(
169186
parameters, "turn_indicator_hold_duration", temp_params.turn_indicator_hold_duration);
170187
update_param<bool>(parameters, "shift_x", temp_params.shift_x);
188+
const bool args_path_changed = temp_params.args_path != previous_args_path;
171189
const bool model_path_changed = temp_params.model_path != previous_model_path;
172190
const bool batch_size_changed = temp_params.batch_size != previous_batch_size;
173191
params_ = temp_params;
174192
turn_indicator_manager_.set_hold_duration(
175193
rclcpp::Duration::from_seconds(params_.turn_indicator_hold_duration));
176194
turn_indicator_manager_.set_keep_offset(params_.turn_indicator_keep_offset);
177195

178-
if ((model_path_changed || batch_size_changed) && tensorrt_inference_) {
179-
diagnostics_inference_->update_level_and_message(
180-
diagnostic_msgs::msg::DiagnosticStatus::WARN, "Loading model weights");
181-
diagnostics_inference_->publish(get_clock()->now());
182-
tensorrt_inference_ = std::make_unique<TensorrtInference>(
183-
params_.model_path, params_.plugins_path, params_.batch_size);
184-
diagnostics_inference_->update_level_and_message(
185-
diagnostic_msgs::msg::DiagnosticStatus::OK, "Model weights loaded");
186-
diagnostics_inference_->publish(get_clock()->now());
196+
if (args_path_changed || model_path_changed || batch_size_changed) {
197+
try {
198+
load_model();
199+
} catch (const std::exception & e) {
200+
RCLCPP_ERROR_STREAM(get_logger(), e.what() << ". Failed to reload model.");
201+
tensorrt_inference_.reset();
202+
SetParametersResult result;
203+
result.successful = false;
204+
result.reason = e.what();
205+
return result;
206+
}
187207
}
188208
}
189209

@@ -516,12 +536,20 @@ void DiffusionPlanner::on_timer()
516536
diagnostics_inference_->clear();
517537

518538
const rclcpp::Time current_time(get_clock()->now());
539+
if (!tensorrt_inference_) {
540+
RCLCPP_WARN_THROTTLE(
541+
get_logger(), *this->get_clock(), constants::LOG_THROTTLE_INTERVAL_MS,
542+
"Model not loaded. Inference is disabled. Check onnx_model_path and args_path parameters.");
543+
diagnostics_inference_->update_level_and_message(DiagnosticStatus::ERROR, "Model not loaded");
544+
diagnostics_inference_->publish(current_time);
545+
return;
546+
}
547+
519548
if (!lane_segment_context_) {
520549
RCLCPP_INFO_THROTTLE(
521550
get_logger(), *this->get_clock(), constants::LOG_THROTTLE_INTERVAL_MS,
522551
"Waiting for map data...");
523-
diagnostics_inference_->update_level_and_message(
524-
diagnostic_msgs::msg::DiagnosticStatus::WARN, "Map data not loaded");
552+
diagnostics_inference_->update_level_and_message(DiagnosticStatus::WARN, "Map data not loaded");
525553
diagnostics_inference_->publish(current_time);
526554
return;
527555
}
@@ -533,7 +561,7 @@ void DiffusionPlanner::on_timer()
533561
get_logger(), *this->get_clock(), constants::LOG_THROTTLE_INTERVAL_MS,
534562
"No input data available for inference");
535563
diagnostics_inference_->update_level_and_message(
536-
diagnostic_msgs::msg::DiagnosticStatus::WARN, "No input data available for inference");
564+
DiagnosticStatus::WARN, "No input data available for inference");
537565
diagnostics_inference_->publish(current_time);
538566
return;
539567
}
@@ -575,7 +603,7 @@ void DiffusionPlanner::on_timer()
575603
get_logger(), *this->get_clock(), constants::LOG_THROTTLE_INTERVAL_MS,
576604
"Input data contains invalid values");
577605
diagnostics_inference_->update_level_and_message(
578-
diagnostic_msgs::msg::DiagnosticStatus::WARN, "Input data contains invalid values");
606+
DiagnosticStatus::WARN, "Input data contains invalid values");
579607
diagnostics_inference_->publish(current_time);
580608
return;
581609
}
@@ -587,7 +615,7 @@ void DiffusionPlanner::on_timer()
587615
get_logger(), *this->get_clock(), constants::LOG_THROTTLE_INTERVAL_MS,
588616
"Inference failed: " << inference_result.error_msg);
589617
diagnostics_inference_->update_level_and_message(
590-
diagnostic_msgs::msg::DiagnosticStatus::ERROR, inference_result.error_msg);
618+
DiagnosticStatus::ERROR, inference_result.error_msg);
591619
diagnostics_inference_->publish(frame_time);
592620
return;
593621
}

0 commit comments

Comments (0)