Skip to content

Commit 7ad3756

Browse files
committed
Allow to pass reason for navigation abort
1 parent c8eb9fb commit 7ad3756

File tree

5 files changed

+117
-28
lines changed

5 files changed

+117
-28
lines changed

core/include/detray/navigation/direct_navigator.hpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -222,12 +222,17 @@ class direct_navigator {
222222
}
223223

224224
DETRAY_HOST_DEVICE
225-
inline auto abort() -> bool {
225+
inline auto abort(const char * = nullptr) -> bool {
226226
m_status = navigation::status::e_abort;
227227
m_heartbeat = false;
228228
return m_heartbeat;
229229
}
230230

231+
template <typename G>
232+
DETRAY_HOST_DEVICE inline auto abort(const G &) -> bool {
233+
return abort();
234+
}
235+
231236
DETRAY_HOST_DEVICE
232237
inline auto exit() -> bool {
233238
m_status = navigation::status::e_on_target;

core/include/detray/navigation/navigator.hpp

+54-3
Original file line numberDiff line numberDiff line change
@@ -394,15 +394,52 @@ class navigator {
394394
/// Navigation state that cannot be recovered from. Leave the other
395395
/// data for inspection.
396396
///
397+
/// @param custom_msg additional information on the reason for the error
398+
///
397399
/// @return navigation heartbeat (dead)
398400
DETRAY_HOST_DEVICE
399-
inline auto abort() -> bool {
401+
inline auto abort(const char *custom_msg = nullptr) -> bool {
400402
m_status = navigation::status::e_abort;
401403
m_heartbeat = false;
402404
// Don't do anything if aborted
403405
m_trust_level = navigation::trust_level::e_full;
406+
407+
/// Wrapper around the custom message that a print inspector can
408+
/// understand
409+
struct message_wrapper {
410+
const char *const m_msg{nullptr};
411+
412+
DETRAY_HOST_DEVICE
413+
constexpr const char *operator()() const { return m_msg; }
414+
};
415+
416+
run_inspector({}, point3_type{0.f, 0.f, 0.f},
417+
vector3_type{0.f, 0.f, 0.f},
418+
"Aborted: ", message_wrapper{custom_msg});
419+
420+
return m_heartbeat;
421+
}
422+
423+
/// Navigation state that cannot be recovered from. Leave the other
424+
/// data for inspection.
425+
///
426+
/// @param debug_msg_generator functor that returns additional
427+
/// information on the reason for the error
428+
///
429+
/// @return navigation heartbeat (dead)
430+
template <typename debug_msg_generator_t>
431+
requires(!std::same_as<char *, debug_msg_generator_t>)
432+
DETRAY_HOST_DEVICE
433+
inline auto abort(const debug_msg_generator_t &debug_msg_generator)
434+
-> bool {
435+
m_status = navigation::status::e_abort;
436+
m_heartbeat = false;
437+
m_trust_level = navigation::trust_level::e_full;
438+
404439
run_inspector({}, point3_type{0.f, 0.f, 0.f},
405-
vector3_type{0.f, 0.f, 0.f}, "Aborted: ");
440+
vector3_type{0.f, 0.f, 0.f},
441+
"Aborted: ", debug_msg_generator);
442+
406443
return m_heartbeat;
407444
}
408445

@@ -579,6 +616,20 @@ class navigator {
579616
}
580617
}
581618

619+
/// Call the navigation inspector
620+
template <typename debug_msg_generator_t>
621+
DETRAY_HOST_DEVICE inline void run_inspector(
622+
[[maybe_unused]] const navigation::config &cfg,
623+
[[maybe_unused]] const point3_type &track_pos,
624+
[[maybe_unused]] const vector3_type &track_dir,
625+
[[maybe_unused]] const char *message,
626+
[[maybe_unused]] const debug_msg_generator_t &msg_gen) {
627+
if constexpr (!std::is_same_v<inspector_t,
628+
navigation::void_inspector>) {
629+
m_inspector(*this, cfg, track_pos, track_dir, message, msg_gen);
630+
}
631+
}
632+
582633
/// Our cache of candidates (intersections with any kind of surface)
583634
candidate_cache_t m_candidates;
584635

@@ -788,7 +839,7 @@ class navigator {
788839
// Unrecoverable
789840
if (navigation.trust_level() != navigation::trust_level::e_full ||
790841
navigation.is_exhausted()) {
791-
navigation.abort();
842+
navigation.abort("Navigator: No reachable surfaces");
792843
}
793844

794845
return is_init;

core/include/detray/propagator/actors/aborters.hpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ struct pathlimit_aborter : actor {
5858
// Check the path limit
5959
if (step_limit <= 0.f) {
6060
// Stop navigation
61-
prop_state._heartbeat &= nav_state.abort();
61+
prop_state._heartbeat &=
62+
nav_state.abort("Aborter: Maximal path length reached");
6263
}
6364

6465
// Don't go over the path limit in the next step
@@ -122,7 +123,8 @@ struct momentum_aborter : actor {
122123

123124
if (mag <= abrt_state.p_limit()) {
124125
// Stop navigation
125-
prop_state._heartbeat &= nav_state.abort();
126+
prop_state._heartbeat &=
127+
nav_state.abort("Aborter: Minimum momentum reached");
126128
}
127129
}
128130
};
@@ -152,7 +154,8 @@ struct target_aborter : actor {
152154
if (navigation.is_on_surface() &&
153155
(navigation.barcode() == abrt_state._target_surface) &&
154156
(stepping.path_length() > 0.f)) {
155-
prop_state._heartbeat &= navigation.abort();
157+
prop_state._heartbeat &=
158+
navigation.abort("Aborter: Reached target surface");
156159
}
157160
}
158161
};

tests/include/detray/test/utils/inspectors.hpp

+48-18
Original file line numberDiff line numberDiff line change
@@ -56,32 +56,38 @@ struct aggregate_inspector {
5656

5757
/// Inspector interface
5858
template <unsigned int current_id = 0, typename state_type,
59-
concepts::point3D point3_t, concepts::vector3D vector3_t>
59+
concepts::point3D point3_t, concepts::vector3D vector3_t,
60+
typename... Args>
6061
DETRAY_HOST_DEVICE auto operator()(state_type &state,
6162
const navigation::config &cfg,
6263
const point3_t &pos,
6364
const vector3_t &dir,
64-
const char *message) {
65+
const char *message, Args &&... args) {
6566
// Call inspector
66-
std::get<current_id>(_inspectors)(state, cfg, pos, dir, message);
67+
std::get<current_id>(_inspectors)(state, cfg, pos, dir, message,
68+
std::forward<Args>(args)...);
6769

6870
// Next inspector
6971
if constexpr (current_id <
7072
std::tuple_size<inspector_tuple_t>::value - 1) {
71-
return operator()<current_id + 1>(state, cfg, pos, dir, message);
73+
return operator()<current_id + 1>(state, cfg, pos, dir, message,
74+
std::forward<Args>(args)...);
7275
}
7376
}
7477

7578
/// Inspector interface
76-
template <unsigned int current_id = 0, typename state_type>
77-
DETRAY_HOST_DEVICE auto operator()(state_type &state, const char *message) {
79+
template <unsigned int current_id = 0, typename state_type,
80+
typename... Args>
81+
DETRAY_HOST_DEVICE auto operator()(state_type &state, const char *message,
82+
Args &&... args) {
7883
// Call inspector
7984
std::get<current_id>(_inspectors)(state, message);
8085

8186
// Next inspector
8287
if constexpr (current_id <
8388
std::tuple_size<inspector_tuple_t>::value - 1) {
84-
return operator()<current_id + 1>(state, message);
89+
return operator()<current_id + 1>(state, message,
90+
std::forward<Args>(args)...);
8591
}
8692
}
8793

@@ -191,12 +197,12 @@ struct object_tracer {
191197

192198
/// Inspector interface
193199
template <typename state_type, concepts::point3D point3_t,
194-
concepts::vector3D vector3_t>
200+
concepts::vector3D vector3_t, typename... Args>
195201
DETRAY_HOST_DEVICE auto operator()(const state_type &state,
196202
const navigation::config &,
197203
const point3_t &pos,
198204
const vector3_t &dir,
199-
const char * /*message*/) {
205+
const char * /*message*/, Args &&...) {
200206

201207
// Record the candidate of an encountered object
202208
if ((is_status(state.status(), navigation_status) || ...)) {
@@ -245,6 +251,8 @@ struct print_inspector {
245251
using view_type = dvector_view<char>;
246252
using const_view_type = dvector_view<const char>;
247253

254+
struct void_generator {};
255+
248256
/// Default constructor
249257
print_inspector() = default;
250258

@@ -272,17 +280,24 @@ struct print_inspector {
272280
/// Move assignemten operator
273281
print_inspector &operator=(print_inspector &&other) = default;
274282

275-
/// Gathers navigation information accross navigator update calls
276-
std::stringstream debug_stream{};
277-
278283
/// Inspector interface. Gathers detailed information during navigation
279284
template <typename state_type, concepts::point3D point3_t,
280-
concepts::vector3D vector3_t>
285+
concepts::vector3D vector3_t,
286+
typename message_generator_t = void_generator>
281287
auto operator()(const state_type &state, const navigation::config &cfg,
282288
const point3_t &track_pos, const vector3_t &track_dir,
283-
const char *message) {
289+
const char *message,
290+
const message_generator_t &msg_gen = {}) {
284291
std::string msg(message);
285-
debug_stream << msg << std::endl;
292+
debug_stream << msg;
293+
if constexpr (!std::same_as<message_generator_t, void_generator>) {
294+
debug_stream << msg_gen();
295+
296+
if (state.status() == navigation::status::e_abort) {
297+
fata_error_msg = msg_gen();
298+
}
299+
}
300+
debug_stream << std::endl;
286301
debug_stream << "----------------------------------------" << std::endl;
287302

288303
debug_stream << navigation::print_state(state);
@@ -293,10 +308,20 @@ struct print_inspector {
293308
}
294309

295310
/// Inspector interface. Print basic state information
296-
template <typename state_type>
297-
auto operator()(const state_type &state, const char *message) {
311+
template <typename state_type,
312+
typename message_generator_t = void_generator>
313+
auto operator()(const state_type &state, const char *message,
314+
const message_generator_t &msg_gen = {}) {
298315
std::string msg(message);
299-
debug_stream << msg << std::endl;
316+
debug_stream << msg;
317+
if constexpr (!std::same_as<message_generator_t, void_generator>) {
318+
debug_stream << msg_gen();
319+
320+
if (state.status() == navigation::status::e_abort) {
321+
fata_error_msg = msg_gen();
322+
}
323+
}
324+
debug_stream << std::endl;
300325
debug_stream << "----------------------------------------" << std::endl;
301326

302327
debug_stream << navigation::print_state(state);
@@ -306,6 +331,11 @@ struct print_inspector {
306331

307332
/// @returns a string representation of the gathered information
308333
std::string to_string() const { return debug_stream.str(); }
334+
335+
/// Gathers navigation information accross navigator update calls
336+
std::stringstream debug_stream{};
337+
/// Special message that is collected if the navigator hits a fatal error
338+
std::string fata_error_msg{""};
309339
};
310340

311341
} // namespace navigation

tests/include/detray/test/validation/navigation_validation_utils.hpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -1074,10 +1074,10 @@ auto compare_to_navigation(
10741074

10751075
// Fatal propagation error: Data unreliable
10761076
if (!success) {
1077-
std::cout << "ERROR: Propagation failure" << std::endl;
1077+
std::cout << "ERROR: Propagation aborted! "
1078+
<< nav_printer.fata_error_msg << std::endl;
10781079

1079-
*debug_file << "ERROR: Propagation failure:\n"
1080-
<< "TEST TRACK " << i;
1080+
*debug_file << "ERROR: Propagation aborted:" << std::endl;
10811081

10821082
n_fatal_error++;
10831083
}

0 commit comments

Comments
 (0)