Skip to content

Commit 0613ce5

Browse files
committed
Cut by ..._track_candidates_per_track as soon as possible
1 parent 8bb5c01 commit 0613ce5

File tree

6 files changed

+58
-46
lines changed

6 files changed

+58
-46
lines changed

device/common/include/traccc/finding/device/impl/build_tracks.ipp

+20-31
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,13 @@ TRACCC_DEVICE inline void build_tracks(const global_index_t globalIndex,
4444

4545
const unsigned int n_cands = L.step + 1 - L.n_skipped;
4646

47+
// Criteria for valid tracks
48+
assert(n_cands >= cfg.min_track_candidates_per_track &&
49+
n_cands <= cfg.max_track_candidates_per_track);
50+
4751
// Resize the candidates with the exact size
4852
cands_per_track.resize(n_cands);
4953

50-
bool success = true;
51-
5254
// Track summary variables
5355
scalar ndf_sum = 0.f;
5456
scalar chi2_sum = 0.f;
@@ -64,11 +66,7 @@ TRACCC_DEVICE inline void build_tracks(const global_index_t globalIndex,
6466
L = links.at(L.previous_candidate_idx);
6567
}
6668

67-
// Break if the measurement is still invalid
68-
if (L.meas_idx >= measurements.size()) {
69-
success = false;
70-
break;
71-
}
69+
assert(L.meas_idx < n_meas);
7270

7371
*it = {measurements.at(L.meas_idx)};
7472
num_inserted++;
@@ -93,36 +91,27 @@ TRACCC_DEVICE inline void build_tracks(const global_index_t globalIndex,
9391
}
9492

9593
#ifndef NDEBUG
96-
if (success) {
97-
// Assert that we inserted exactly as many elements as we reserved
98-
// space for.
99-
assert(num_inserted == cands_per_track.size());
100-
101-
// Assert that we did not make any duplicate track states.
102-
for (unsigned int i = 0; i < cands_per_track.size(); ++i) {
103-
for (unsigned int j = 0; j < cands_per_track.size(); ++j) {
104-
if (i != j) {
105-
// TODO: Re-enable me!
106-
// assert(cands_per_track.at(i).measurement_id !=
107-
// cands_per_track.at(j).measurement_id);
108-
}
94+
// Assert that we inserted exactly as many elements as we reserved
95+
// space for.
96+
assert(num_inserted == cands_per_track.size());
97+
98+
// Assert that we did not make any duplicate track states.
99+
for (unsigned int i = 0; i < cands_per_track.size(); ++i) {
100+
for (unsigned int j = 0; j < cands_per_track.size(); ++j) {
101+
if (i != j) {
102+
// TODO: Re-enable me!
103+
// assert(cands_per_track.at(i).measurement_id !=
104+
// cands_per_track.at(j).measurement_id);
109105
}
110106
}
111107
}
112108
#endif
113109

114-
// NOTE: We may at some point want to assert that `success` is true
110+
vecmem::device_atomic_ref<unsigned int> num_valid_tracks(
111+
*payload.n_valid_tracks);
115112

116-
// Criteria for valid tracks
117-
if (n_cands >= cfg.min_track_candidates_per_track &&
118-
n_cands <= cfg.max_track_candidates_per_track && success) {
119-
120-
vecmem::device_atomic_ref<unsigned int> num_valid_tracks(
121-
*payload.n_valid_tracks);
122-
123-
const unsigned int pos = num_valid_tracks.fetch_add(1);
124-
valid_indices[pos] = globalIndex;
125-
}
113+
const unsigned int pos = num_valid_tracks.fetch_add(1);
114+
valid_indices[pos] = globalIndex;
126115
}
127116

128117
} // namespace traccc::device

device/common/include/traccc/finding/device/impl/find_tracks.ipp

+8-4
Original file line numberDiff line numberDiff line change
@@ -312,10 +312,14 @@ TRACCC_DEVICE inline void find_tracks(
312312
payload.step == 0 ? 0 : links.at(prev_link_idx).n_skipped;
313313

314314
if (n_skipped >= cfg.max_num_skipping_per_cand) {
315-
// In case of max skipping being 0 and first step being skipped,
316-
// the links are empty, and the tip has nowhere to point
317-
assert(payload.step > 0);
318-
tips.push_back(prev_link_idx);
315+
const unsigned int n_cands = payload.step - n_skipped;
316+
if (n_cands >= cfg.min_track_candidates_per_track) {
317+
// In case of max skipping and min length being 0, and first
318+
// step being skipped, the links are empty, and the tip has
319+
// nowhere to point
320+
assert(payload.step > 0);
321+
tips.push_back(prev_link_idx);
322+
}
319323
} else {
320324
// Add measurement candidates to link
321325
const unsigned int l_pos = links.bulk_append_implicit(1);

device/common/include/traccc/finding/device/impl/propagate_to_next_surface.ipp

+21-9
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ TRACCC_DEVICE inline void propagate_to_next_surface(
3030

3131
const unsigned int param_id = param_ids.at(globalIndex);
3232

33+
// Links
34+
vecmem::device_vector<candidate_link> links(payload.links_view);
35+
36+
const unsigned int link_idx = payload.prev_links_idx + param_id;
37+
const auto& link = links.at(link_idx);
38+
assert(link.step == payload.step);
39+
const unsigned int n_cands = link.step + 1 - link.n_skipped;
40+
3341
// Parameter liveness
3442
vecmem::device_vector<unsigned int> params_liveness(
3543
payload.params_liveness_view);
@@ -47,6 +55,16 @@ TRACCC_DEVICE inline void propagate_to_next_surface(
4755
return;
4856
}
4957

58+
// If no more CKF step is expected, current candidate is kept as a tip and
59+
// the time-consuming propagation is avoided
60+
if (payload.step == cfg.max_track_candidates_per_track - 1) {
61+
if (n_cands >= cfg.min_track_candidates_per_track) {
62+
tips.push_back(link_idx);
63+
}
64+
params_liveness[param_id] = 0u;
65+
return;
66+
}
67+
5068
// Input bound track parameter
5169
const bound_track_parameters<> in_par = params.at(param_id);
5270

@@ -82,18 +100,12 @@ TRACCC_DEVICE inline void propagate_to_next_surface(
82100
// If a surface found, add the parameter for the next step
83101
if (s4.success) {
84102
params[param_id] = propagation._stepping.bound_params();
85-
86-
if (payload.step == cfg.max_track_candidates_per_track - 1) {
87-
tips.push_back(payload.prev_links_idx + param_id);
88-
params_liveness[param_id] = 0u;
89-
} else {
90-
params_liveness[param_id] = 1u;
91-
}
103+
params_liveness[param_id] = 1u;
92104
} else {
93105
params_liveness[param_id] = 0u;
94106

95-
if (payload.step >= cfg.min_track_candidates_per_track - 1) {
96-
tips.push_back(payload.prev_links_idx + param_id);
107+
if (n_cands >= cfg.min_track_candidates_per_track) {
108+
tips.push_back(link_idx);
97109
}
98110
}
99111
}

device/common/include/traccc/finding/device/propagate_to_next_surface.hpp

+5
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ struct propagate_to_next_surface_payload {
5050
*/
5151
vecmem::data::vector_view<const unsigned int> param_ids_view;
5252

53+
/**
54+
* @brief View object to the link vector
55+
*/
56+
vecmem::data::vector_view<candidate_link> links_view;
57+
5358
/**
5459
* @brief Index in the link vector at which the current step starts
5560
*/

device/cuda/src/finding/finding_algorithm.cu

+1
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
362362
.params_view = in_params_buffer,
363363
.params_liveness_view = param_liveness_buffer,
364364
.param_ids_view = param_ids_buffer,
365+
.links_view = links_buffer,
365366
.prev_links_idx = step_to_link_idx_map[step],
366367
.step = step,
367368
.n_in_params = n_candidates,

device/sycl/src/finding/find_tracks.hpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,7 @@ track_candidate_container_types::buffer find_tracks(
396396
param_liveness =
397397
vecmem::get_data(param_liveness_buffer),
398398
param_ids = vecmem::get_data(param_ids_buffer),
399+
links_view = vecmem::get_data(links_buffer),
399400
prev_links_idx = step_to_link_idx_map[step], step,
400401
n_candidates, tips = vecmem::get_data(tips_buffer)](
401402
::sycl::nd_item<1> item) {
@@ -404,8 +405,8 @@ track_candidate_container_types::buffer find_tracks(
404405
typename stepper_t::magnetic_field_type>(
405406
details::global_index(item), config,
406407
{det, field, in_params, param_liveness,
407-
param_ids, prev_links_idx, step, n_candidates,
408-
tips});
408+
param_ids, links_view, prev_links_idx, step,
409+
n_candidates, tips});
409410
});
410411
})
411412
.wait_and_throw();

0 commit comments

Comments
 (0)