Skip to content

Commit 791fbfe

Browse files
committed
Cut by ..._track_candidates_per_track as soon as possible
1 parent 2eb68d8 commit 791fbfe

19 files changed

+98
-409
lines changed

device/alpaka/CMakeLists.txt

-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ traccc_add_alpaka_library( traccc_alpaka alpaka TYPE SHARED
5252
"src/finding/kernels/make_barcode_sequence.hpp"
5353
"src/finding/kernels/apply_interaction.hpp"
5454
"src/finding/kernels/fill_sort_keys.hpp"
55-
"src/finding/kernels/prune_tracks.hpp"
5655
"src/finding/kernels/build_tracks.hpp"
5756
"src/finding/kernels/find_tracks.hpp"
5857
"src/finding/kernels/propagate_to_next_surface.hpp"

device/alpaka/src/finding/finding_algorithm.cpp

+7-54
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include "./kernels/find_tracks.hpp"
1818
#include "./kernels/make_barcode_sequence.hpp"
1919
#include "./kernels/propagate_to_next_surface.hpp"
20-
#include "./kernels/prune_tracks.hpp"
2120
#include "traccc/definitions/primitives.hpp"
2221
#include "traccc/definitions/qualifiers.hpp"
2322
#include "traccc/edm/device/sort_key.hpp"
@@ -306,6 +305,10 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
306305
::alpaka::wait(queue);
307306
}
308307

308+
if (step == m_cfg.max_track_candidates_per_track - 1) {
309+
break;
310+
}
311+
309312
if (n_candidates > 0) {
310313
/*****************************************************************
311314
* Kernel4: Get key and value for parameter sorting
@@ -413,25 +416,9 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
413416
m_copy.setup(track_candidates_buffer.headers)->ignore();
414417
m_copy.setup(track_candidates_buffer.items)->ignore();
415418

416-
// Create buffer for valid indices
417-
vecmem::data::vector_buffer<unsigned int> valid_indices_buffer(n_tips_total,
418-
m_mr.main);
419-
420-
// Count the number of valid tracks
421-
auto bufHost_n_valid_tracks =
422-
::alpaka::allocBuf<unsigned int, Idx>(devHost, 1u);
423-
unsigned int* n_valid_tracks =
424-
::alpaka::getPtrNative(bufHost_n_valid_tracks);
425-
::alpaka::memset(queue, bufHost_n_valid_tracks, 0);
426-
::alpaka::wait(queue);
427-
428419
// @Note: nBlocks can be zero in case there is no tip. This happens when
429420
// chi2_max config is set tightly and no tips are found
430421
if (n_tips_total > 0) {
431-
auto n_valid_tracks_device =
432-
::alpaka::allocBuf<unsigned int, Idx>(devAcc, 1u);
433-
::alpaka::memset(queue, n_valid_tracks_device, 0);
434-
435422
Idx blocksPerGrid =
436423
(n_tips_total + threadsPerBlock - 1) / threadsPerBlock;
437424
auto workDiv = makeWorkDiv<Acc>(blocksPerGrid, threadsPerBlock);
@@ -440,49 +427,15 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
440427
track_candidates_buffer);
441428

442429
::alpaka::exec<Acc>(
443-
queue, workDiv, BuildTracksKernel{}, m_cfg,
430+
queue, workDiv, BuildTracksKernel{},
444431
device::build_tracks_payload{
445432
measurements, vecmem::get_data(seeds_buffer),
446433
vecmem::get_data(links_buffer), vecmem::get_data(tips_buffer),
447-
track_candidates_view, vecmem::get_data(valid_indices_buffer),
448-
::alpaka::getPtrNative(n_valid_tracks_device)});
449-
::alpaka::wait(queue);
450-
451-
// Global counter object: Device -> Host
452-
::alpaka::memcpy(queue, bufHost_n_valid_tracks, n_valid_tracks_device);
453-
::alpaka::wait(queue);
454-
}
455-
456-
// Create pruned candidate buffer
457-
track_candidate_container_types::buffer prune_candidates_buffer{
458-
{*n_valid_tracks, m_mr.main},
459-
{std::vector<std::size_t>(*n_valid_tracks,
460-
m_cfg.max_track_candidates_per_track),
461-
m_mr.main, m_mr.host, vecmem::data::buffer_type::resizable}};
462-
463-
m_copy.setup(prune_candidates_buffer.headers)->ignore();
464-
m_copy.setup(prune_candidates_buffer.items)->ignore();
465-
466-
if (*n_valid_tracks > 0) {
467-
Idx blocksPerGrid =
468-
(*n_valid_tracks + threadsPerBlock - 1) / threadsPerBlock;
469-
auto workDiv = makeWorkDiv<Acc>(blocksPerGrid, threadsPerBlock);
470-
471-
track_candidate_container_types::const_view track_candidates_view(
472-
track_candidates_buffer);
473-
474-
track_candidate_container_types::view prune_candidates_view(
475-
prune_candidates_buffer);
476-
477-
::alpaka::exec<Acc>(
478-
queue, workDiv, PruneTracksKernel{},
479-
device::prune_tracks_payload{track_candidates_view,
480-
vecmem::get_data(valid_indices_buffer),
481-
prune_candidates_view});
434+
track_candidates_view});
482435
::alpaka::wait(queue);
483436
}
484437

485-
return prune_candidates_buffer;
438+
return track_candidates_buffer;
486439
}
487440

488441
// Explicit template instantiation

device/alpaka/src/finding/kernels/build_tracks.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ namespace traccc::alpaka {
1919

2020
struct BuildTracksKernel {
2121
template <typename TAcc>
22-
ALPAKA_FN_ACC void operator()(TAcc const& acc, const finding_config cfg,
22+
ALPAKA_FN_ACC void operator()(TAcc const& acc,
2323
device::build_tracks_payload payload) const {
2424

2525
device::global_index_t globalThreadIdx =
2626
::alpaka::getIdx<::alpaka::Grid, ::alpaka::Threads>(acc)[0];
2727

28-
device::build_tracks(globalThreadIdx, cfg, payload);
28+
device::build_tracks(globalThreadIdx, payload);
2929
}
3030
};
3131

device/alpaka/src/finding/kernels/prune_tracks.hpp

-28
This file was deleted.

device/common/CMakeLists.txt

-2
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,12 @@ traccc_add_library( traccc_device_common device_common TYPE INTERFACE
6464
"include/traccc/finding/device/fill_sort_keys.hpp"
6565
"include/traccc/finding/device/make_barcode_sequence.hpp"
6666
"include/traccc/finding/device/propagate_to_next_surface.hpp"
67-
"include/traccc/finding/device/prune_tracks.hpp"
6867
"include/traccc/finding/device/impl/apply_interaction.ipp"
6968
"include/traccc/finding/device/impl/build_tracks.ipp"
7069
"include/traccc/finding/device/impl/find_tracks.ipp"
7170
"include/traccc/finding/device/impl/fill_sort_keys.ipp"
7271
"include/traccc/finding/device/impl/make_barcode_sequence.ipp"
7372
"include/traccc/finding/device/impl/propagate_to_next_surface.ipp"
74-
"include/traccc/finding/device/impl/prune_tracks.ipp"
7573
# Track fitting function(s).
7674
"include/traccc/fitting/device/fit.hpp"
7775
"include/traccc/fitting/device/impl/fit.ipp"

device/common/include/traccc/finding/device/build_tracks.hpp

+1-13
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,6 @@ struct build_tracks_payload {
5252
* @brief View object to the vector of track candidates
5353
*/
5454
track_candidate_container_types::view track_candidates_view;
55-
56-
/**
57-
* @brief View object to the vector of indices meeting the selection
58-
* criteria
59-
*/
60-
vecmem::data::vector_view<unsigned int> valid_indices_view;
61-
62-
/**
63-
* @brief The number of valid tracks meeting criteria
64-
*/
65-
unsigned int* n_valid_tracks;
6655
};
6756

6857
/// Function for building full tracks from the link container:
@@ -75,8 +64,7 @@ struct build_tracks_payload {
7564
/// @param[inout] payload The function call payload
7665
///
7766
TRACCC_HOST_DEVICE inline void build_tracks(
78-
global_index_t globalIndex, const finding_config& cfg,
79-
const build_tracks_payload& payload);
67+
global_index_t globalIndex, const build_tracks_payload& payload);
8068

8169
} // namespace traccc::device
8270

device/common/include/traccc/finding/device/impl/build_tracks.ipp

+13-38
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
namespace traccc::device {
1414

1515
TRACCC_HOST_DEVICE inline void build_tracks(
16-
const global_index_t globalIndex, const finding_config& cfg,
17-
const build_tracks_payload& payload) {
16+
const global_index_t globalIndex, const build_tracks_payload& payload) {
1817

1918
const measurement_collection_types::const_device measurements(
2019
payload.measurements_view);
@@ -29,9 +28,6 @@ TRACCC_HOST_DEVICE inline void build_tracks(
2928
track_candidate_container_types::device track_candidates(
3029
payload.track_candidates_view);
3130

32-
vecmem::device_vector<unsigned int> valid_indices(
33-
payload.valid_indices_view);
34-
3531
if (globalIndex >= tips.size()) {
3632
return;
3733
}
@@ -50,8 +46,6 @@ TRACCC_HOST_DEVICE inline void build_tracks(
5046
// Resize the candidates with the exact size
5147
cands_per_track.resize(n_cands);
5248

53-
bool success = true;
54-
5549
// Track summary variables
5650
scalar ndf_sum = 0.f;
5751
scalar chi2_sum = 0.f;
@@ -67,11 +61,7 @@ TRACCC_HOST_DEVICE inline void build_tracks(
6761
L = links.at(L.previous_candidate_idx);
6862
}
6963

70-
// Break if the measurement is still invalid
71-
if (L.meas_idx >= measurements.size()) {
72-
success = false;
73-
break;
74-
}
64+
assert(L.meas_idx < n_meas);
7565

7666
*it = {measurements.at(L.meas_idx)};
7767
num_inserted++;
@@ -97,36 +87,21 @@ TRACCC_HOST_DEVICE inline void build_tracks(
9787
}
9888

9989
#ifndef NDEBUG
100-
if (success) {
101-
// Assert that we inserted exactly as many elements as we reserved
102-
// space for.
103-
assert(num_inserted == cands_per_track.size());
104-
105-
// Assert that we did not make any duplicate track states.
106-
for (unsigned int i = 0; i < cands_per_track.size(); ++i) {
107-
for (unsigned int j = 0; j < cands_per_track.size(); ++j) {
108-
if (i != j) {
109-
// TODO: Re-enable me!
110-
// assert(cands_per_track.at(i).measurement_id !=
111-
// cands_per_track.at(j).measurement_id);
112-
}
90+
// Assert that we inserted exactly as many elements as we reserved
91+
// space for.
92+
assert(num_inserted == cands_per_track.size());
93+
94+
// Assert that we did not make any duplicate track states.
95+
for (unsigned int i = 0; i < cands_per_track.size(); ++i) {
96+
for (unsigned int j = 0; j < cands_per_track.size(); ++j) {
97+
if (i != j) {
98+
// TODO: Re-enable me!
99+
// assert(cands_per_track.at(i).measurement_id !=
100+
// cands_per_track.at(j).measurement_id);
113101
}
114102
}
115103
}
116104
#endif
117-
118-
// NOTE: We may at some point want to assert that `success` is true
119-
120-
// Criteria for valid tracks
121-
if (n_cands >= cfg.min_track_candidates_per_track &&
122-
n_cands <= cfg.max_track_candidates_per_track && success) {
123-
124-
vecmem::device_atomic_ref<unsigned int> num_valid_tracks(
125-
*payload.n_valid_tracks);
126-
127-
const unsigned int pos = num_valid_tracks.fetch_add(1);
128-
valid_indices[pos] = globalIndex;
129-
}
130105
}
131106

132107
} // namespace traccc::device

device/common/include/traccc/finding/device/impl/find_tracks.ipp

+23-7
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ TRACCC_HOST_DEVICE inline void find_tracks(
7777

7878
const unsigned int in_param_id = thread_id.getGlobalThreadIdX();
7979

80+
const bool last_step =
81+
payload.step == cfg.max_track_candidates_per_track - 1;
82+
8083
/*
8184
* Step 1 of this kernel is to determine which indices belong to which
8285
* parameter. Because the measurements are guaranteed to be grouped, we can
@@ -262,10 +265,19 @@ TRACCC_HOST_DEVICE inline void find_tracks(
262265
.seed_idx = seed_idx,
263266
.n_skipped = n_skipped,
264267
.chi2 = chi2};
265-
266268
out_params.at(l_pos - payload.curr_links_idx) =
267269
trk_state.filtered();
268-
out_params_liveness.at(l_pos - payload.curr_links_idx) = 1u;
270+
out_params_liveness.at(l_pos - payload.curr_links_idx) =
271+
static_cast<unsigned int>(!last_step);
272+
273+
const unsigned int n_cands = payload.step + 1 - n_skipped;
274+
275+
// If no more CKF step is expected, current candidate is kept as
276+
// a tip
277+
if (last_step &&
278+
n_cands >= cfg.min_track_candidates_per_track) {
279+
tips.push_back(l_pos);
280+
}
269281
}
270282
}
271283

@@ -312,11 +324,15 @@ TRACCC_HOST_DEVICE inline void find_tracks(
312324
const unsigned int n_skipped =
313325
payload.step == 0 ? 0 : links.at(prev_link_idx).n_skipped;
314326

315-
if (n_skipped >= cfg.max_num_skipping_per_cand) {
316-
// In case of max skipping being 0 and first step being skipped,
317-
// the links are empty, and the tip has nowhere to point
318-
assert(payload.step > 0);
319-
tips.push_back(prev_link_idx);
327+
if (n_skipped >= cfg.max_num_skipping_per_cand || last_step) {
328+
const unsigned int n_cands = payload.step - n_skipped;
329+
if (n_cands >= cfg.min_track_candidates_per_track) {
330+
// In case of max skipping and min length being 0, and first
331+
// step being skipped, the links are empty, and the tip has
332+
// nowhere to point
333+
assert(payload.step > 0);
334+
tips.push_back(prev_link_idx);
335+
}
320336
} else {
321337
// Add measurement candidates to link
322338
const unsigned int l_pos = links.bulk_append_implicit(1);

device/common/include/traccc/finding/device/impl/propagate_to_next_surface.ipp

+11-9
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ TRACCC_HOST_DEVICE inline void propagate_to_next_surface(
3232

3333
const unsigned int param_id = param_ids.at(globalIndex);
3434

35+
// Links
36+
vecmem::device_vector<const candidate_link> links(payload.links_view);
37+
38+
const unsigned int link_idx = payload.prev_links_idx + param_id;
39+
const auto& link = links.at(link_idx);
40+
assert(link.step == payload.step);
41+
const unsigned int n_cands = link.step + 1 - link.n_skipped;
42+
3543
// Parameter liveness
3644
vecmem::device_vector<unsigned int> params_liveness(
3745
payload.params_liveness_view);
@@ -93,18 +101,12 @@ TRACCC_HOST_DEVICE inline void propagate_to_next_surface(
93101
assert(propagation._navigation.is_on_sensitive());
94102

95103
params[param_id] = propagation._stepping.bound_params();
96-
97-
if (payload.step == cfg.max_track_candidates_per_track - 1) {
98-
tips.push_back(payload.prev_links_idx + param_id);
99-
params_liveness[param_id] = 0u;
100-
} else {
101-
params_liveness[param_id] = 1u;
102-
}
104+
params_liveness[param_id] = 1u;
103105
} else {
104106
params_liveness[param_id] = 0u;
105107

106-
if (payload.step >= cfg.min_track_candidates_per_track - 1) {
107-
tips.push_back(payload.prev_links_idx + param_id);
108+
if (n_cands >= cfg.min_track_candidates_per_track) {
109+
tips.push_back(link_idx);
108110
}
109111
}
110112
}

0 commit comments

Comments
 (0)