Skip to content

Commit 974a329

Browse files
committed
Cut by ..._track_candidates_per_track as soon as possible
1 parent 643e991 commit 974a329

File tree

7 files changed

+83
-57
lines changed

7 files changed

+83
-57
lines changed

device/alpaka/src/finding/finding_algorithm.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,10 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
309309
::alpaka::wait(queue);
310310
}
311311

312+
if (step == m_cfg.max_track_candidates_per_track - 1) {
313+
break;
314+
}
315+
312316
if (n_candidates > 0) {
313317
/*****************************************************************
314318
* Kernel4: Get key and value for parameter sorting

device/common/include/traccc/finding/device/impl/build_tracks.ipp

+22-32
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
namespace traccc::device {
1414

1515
TRACCC_HOST_DEVICE inline void build_tracks(
16-
const global_index_t globalIndex, const finding_config& cfg,
16+
const global_index_t globalIndex,
17+
[[maybe_unused]] const finding_config& cfg,
1718
const build_tracks_payload& payload) {
1819

1920
const measurement_collection_types::const_device measurements(
@@ -47,11 +48,13 @@ TRACCC_HOST_DEVICE inline void build_tracks(
4748

4849
const unsigned int n_cands = L.step + 1 - L.n_skipped;
4950

51+
// Criteria for valid tracks
52+
assert(n_cands >= cfg.min_track_candidates_per_track &&
53+
n_cands <= cfg.max_track_candidates_per_track);
54+
5055
// Resize the candidates with the exact size
5156
cands_per_track.resize(n_cands);
5257

53-
bool success = true;
54-
5558
// Track summary variables
5659
scalar ndf_sum = 0.f;
5760
scalar chi2_sum = 0.f;
@@ -67,11 +70,7 @@ TRACCC_HOST_DEVICE inline void build_tracks(
6770
L = links.at(L.previous_candidate_idx);
6871
}
6972

70-
// Break if the measurement is still invalid
71-
if (L.meas_idx >= measurements.size()) {
72-
success = false;
73-
break;
74-
}
73+
assert(L.meas_idx < n_meas);
7574

7675
*it = {measurements.at(L.meas_idx)};
7776
num_inserted++;
@@ -97,36 +96,27 @@ TRACCC_HOST_DEVICE inline void build_tracks(
9796
}
9897

9998
#ifndef NDEBUG
100-
if (success) {
101-
// Assert that we inserted exactly as many elements as we reserved
102-
// space for.
103-
assert(num_inserted == cands_per_track.size());
104-
105-
// Assert that we did not make any duplicate track states.
106-
for (unsigned int i = 0; i < cands_per_track.size(); ++i) {
107-
for (unsigned int j = 0; j < cands_per_track.size(); ++j) {
108-
if (i != j) {
109-
// TODO: Re-enable me!
110-
// assert(cands_per_track.at(i).measurement_id !=
111-
// cands_per_track.at(j).measurement_id);
112-
}
99+
// Assert that we inserted exactly as many elements as we reserved
100+
// space for.
101+
assert(num_inserted == cands_per_track.size());
102+
103+
// Assert that we did not make any duplicate track states.
104+
for (unsigned int i = 0; i < cands_per_track.size(); ++i) {
105+
for (unsigned int j = 0; j < cands_per_track.size(); ++j) {
106+
if (i != j) {
107+
// TODO: Re-enable me!
108+
// assert(cands_per_track.at(i).measurement_id !=
109+
// cands_per_track.at(j).measurement_id);
113110
}
114111
}
115112
}
116113
#endif
117114

118-
// NOTE: We may at some point want to assert that `success` is true
115+
vecmem::device_atomic_ref<unsigned int> num_valid_tracks(
116+
*payload.n_valid_tracks);
119117

120-
// Criteria for valid tracks
121-
if (n_cands >= cfg.min_track_candidates_per_track &&
122-
n_cands <= cfg.max_track_candidates_per_track && success) {
123-
124-
vecmem::device_atomic_ref<unsigned int> num_valid_tracks(
125-
*payload.n_valid_tracks);
126-
127-
const unsigned int pos = num_valid_tracks.fetch_add(1);
128-
valid_indices[pos] = globalIndex;
129-
}
118+
const unsigned int pos = num_valid_tracks.fetch_add(1);
119+
valid_indices[pos] = globalIndex;
130120
}
131121

132122
} // namespace traccc::device

device/common/include/traccc/finding/device/impl/find_tracks.ipp

+22-6
Original file line numberDiff line numberDiff line change
@@ -256,16 +256,28 @@ TRACCC_HOST_DEVICE inline void find_tracks(
256256
const unsigned int n_skipped =
257257
payload.step == 0 ? 0 : links.at(prev_link_idx).n_skipped;
258258

259+
const bool last_step =
260+
payload.step == cfg.max_track_candidates_per_track - 1;
261+
259262
links.at(l_pos) = {.step = payload.step,
260263
.previous_candidate_idx = prev_link_idx,
261264
.meas_idx = meas_idx,
262265
.seed_idx = seed_idx,
263266
.n_skipped = n_skipped,
264267
.chi2 = chi2};
265-
266268
out_params.at(l_pos - payload.curr_links_idx) =
267269
trk_state.filtered();
268-
out_params_liveness.at(l_pos - payload.curr_links_idx) = 1u;
270+
out_params_liveness.at(l_pos - payload.curr_links_idx) =
271+
static_cast<unsigned int>(!last_step);
272+
273+
const unsigned int n_cands = payload.step + 1 - n_skipped;
274+
275+
// If no more CKF steps are expected, the current candidate is kept as
276+
// a tip
277+
if (last_step &&
278+
n_cands >= cfg.min_track_candidates_per_track) {
279+
tips.push_back(l_pos);
280+
}
269281
}
270282
}
271283

@@ -318,10 +330,14 @@ TRACCC_HOST_DEVICE inline void find_tracks(
318330
payload.step == 0 ? 0 : links.at(prev_link_idx).n_skipped;
319331

320332
if (n_skipped >= cfg.max_num_skipping_per_cand) {
321-
// In case of max skipping being 0 and first step being skipped,
322-
// the links are empty, and the tip has nowhere to point
323-
assert(payload.step > 0);
324-
tips.push_back(prev_link_idx);
333+
const unsigned int n_cands = payload.step - n_skipped;
334+
if (n_cands >= cfg.min_track_candidates_per_track) {
335+
// In case of max skipping and min length being 0, and first
336+
// step being skipped, the links are empty, and the tip has
337+
// nowhere to point
338+
assert(payload.step > 0);
339+
tips.push_back(prev_link_idx);
340+
}
325341
} else {
326342
// Add measurement candidates to link
327343
const unsigned int l_pos = links.bulk_append_implicit(1);

device/common/include/traccc/finding/device/impl/propagate_to_next_surface.ipp

+11-9
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ TRACCC_HOST_DEVICE inline void propagate_to_next_surface(
3030

3131
const unsigned int param_id = param_ids.at(globalIndex);
3232

33+
// Links
34+
vecmem::device_vector<const candidate_link> links(payload.links_view);
35+
36+
const unsigned int link_idx = payload.prev_links_idx + param_id;
37+
const auto& link = links.at(link_idx);
38+
assert(link.step == payload.step);
39+
const unsigned int n_cands = link.step + 1 - link.n_skipped;
40+
3341
// Parameter liveness
3442
vecmem::device_vector<unsigned int> params_liveness(
3543
payload.params_liveness_view);
@@ -82,18 +90,12 @@ TRACCC_HOST_DEVICE inline void propagate_to_next_surface(
8290
// If a surface found, add the parameter for the next step
8391
if (s4.success) {
8492
params[param_id] = propagation._stepping.bound_params();
85-
86-
if (payload.step == cfg.max_track_candidates_per_track - 1) {
87-
tips.push_back(payload.prev_links_idx + param_id);
88-
params_liveness[param_id] = 0u;
89-
} else {
90-
params_liveness[param_id] = 1u;
91-
}
93+
params_liveness[param_id] = 1u;
9294
} else {
9395
params_liveness[param_id] = 0u;
9496

95-
if (payload.step >= cfg.min_track_candidates_per_track - 1) {
96-
tips.push_back(payload.prev_links_idx + param_id);
97+
if (n_cands >= cfg.min_track_candidates_per_track) {
98+
tips.push_back(link_idx);
9799
}
98100
}
99101
}

device/common/include/traccc/finding/device/propagate_to_next_surface.hpp

+5
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ struct propagate_to_next_surface_payload {
5050
*/
5151
vecmem::data::vector_view<const unsigned int> param_ids_view;
5252

53+
/**
54+
* @brief View object to the vector of candidate links
55+
*/
56+
vecmem::data::vector_view<const candidate_link> links_view;
57+
5358
/**
5459
* @brief Index in the link vector at which the current step starts
5560
*/

device/cuda/src/finding/finding_algorithm.cu

+12-8
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,7 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
189189

190190
unsigned int n_in_params = n_seeds;
191191

192-
for (unsigned int step = 0;
193-
step < m_cfg.max_track_candidates_per_track && n_in_params > 0;
194-
step++) {
192+
for (unsigned int step = 0; n_in_params > 0; step++) {
195193

196194
/*****************************************************************
197195
* Kernel2: Apply material interaction
@@ -268,13 +266,12 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
268266
const unsigned int nThreads = m_warp_size * 2;
269267
const unsigned int nBlocks =
270268
(n_in_params + nThreads - 1) / nThreads;
269+
const std::size_t shared_size =
270+
nThreads * sizeof(unsigned int) +
271+
2 * nThreads * sizeof(std::pair<unsigned int, unsigned int>);
271272

272273
kernels::find_tracks<std::decay_t<detector_type>>
273-
<<<nBlocks, nThreads,
274-
nThreads * sizeof(unsigned int) +
275-
2 * nThreads *
276-
sizeof(std::pair<unsigned int, unsigned int>),
277-
stream>>>(
274+
<<<nBlocks, nThreads, shared_size, stream>>>(
278275
m_cfg,
279276
device::find_tracks_payload<std::decay_t<detector_type>>{
280277
.det_data = det_view,
@@ -306,6 +303,12 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
306303
m_stream.synchronize();
307304
}
308305

306+
// If no more CKF steps are expected, the tips and links are populated,
307+
// and any further time-consuming action is avoided
308+
if (step == m_cfg.max_track_candidates_per_track - 1) {
309+
break;
310+
}
311+
309312
if (n_candidates > 0) {
310313
/*****************************************************************
311314
* Kernel4: Get key and value for parameter sorting
@@ -362,6 +365,7 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
362365
.params_view = in_params_buffer,
363366
.params_liveness_view = param_liveness_buffer,
364367
.param_ids_view = param_ids_buffer,
368+
.links_view = links_buffer,
365369
.prev_links_idx = step_to_link_idx_map[step],
366370
.step = step,
367371
.n_in_params = n_candidates,

device/sycl/src/finding/find_tracks.hpp

+7-2
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,10 @@ track_candidate_container_types::buffer find_tracks(
325325
n_candidates =
326326
step_to_link_idx_map[step + 1] - step_to_link_idx_map[step];
327327

328+
if (step == config.max_track_candidates_per_track - 1) {
329+
break;
330+
}
331+
328332
if (n_candidates > 0) {
329333
/*****************************************************************
330334
* Kernel4: Get key and value for parameter sorting
@@ -396,6 +400,7 @@ track_candidate_container_types::buffer find_tracks(
396400
param_liveness =
397401
vecmem::get_data(param_liveness_buffer),
398402
param_ids = vecmem::get_data(param_ids_buffer),
403+
links_view = vecmem::get_data(links_buffer),
399404
prev_links_idx = step_to_link_idx_map[step], step,
400405
n_candidates, tips = vecmem::get_data(tips_buffer)](
401406
::sycl::nd_item<1> item) {
@@ -404,8 +409,8 @@ track_candidate_container_types::buffer find_tracks(
404409
typename stepper_t::magnetic_field_type>(
405410
details::global_index(item), config,
406411
{det, field, in_params, param_liveness,
407-
param_ids, prev_links_idx, step, n_candidates,
408-
tips});
412+
param_ids, links_view, prev_links_idx, step,
413+
n_candidates, tips});
409414
});
410415
})
411416
.wait_and_throw();

0 commit comments

Comments
 (0)