Skip to content

Commit 3aa466b

Browse files
committed
Improve memory usage in track finding postamble
The code which turns the tips of our track finding into actual tracks uses an excessive amount of memory, as it massively overallocates. Indeed, it allocates memory as though all tips have the maximum number of track states, which is unrealistic. This commit makes it so that the number of valid track states is counted instead, making the allocation more precise. In my measurements, this more than halves the memory usage of traccc on $\langle\mu\rangle = 200$ ttbar events in the ODD.
1 parent 6fee5ca commit 3aa466b

12 files changed

+190
-395
lines changed

device/common/CMakeLists.txt

-2
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,12 @@ traccc_add_library( traccc_device_common device_common TYPE INTERFACE
6464
"include/traccc/finding/device/fill_sort_keys.hpp"
6565
"include/traccc/finding/device/make_barcode_sequence.hpp"
6666
"include/traccc/finding/device/propagate_to_next_surface.hpp"
67-
"include/traccc/finding/device/prune_tracks.hpp"
6867
"include/traccc/finding/device/impl/apply_interaction.ipp"
6968
"include/traccc/finding/device/impl/build_tracks.ipp"
7069
"include/traccc/finding/device/impl/find_tracks.ipp"
7170
"include/traccc/finding/device/impl/fill_sort_keys.ipp"
7271
"include/traccc/finding/device/impl/make_barcode_sequence.ipp"
7372
"include/traccc/finding/device/impl/propagate_to_next_surface.ipp"
74-
"include/traccc/finding/device/impl/prune_tracks.ipp"
7573
# Track fitting function(s).
7674
"include/traccc/fitting/device/fit.hpp"
7775
"include/traccc/fitting/device/impl/fit.ipp"

device/common/include/traccc/finding/device/build_tracks.hpp

+2-13
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,9 @@ struct build_tracks_payload {
5050
tips_view;
5151

5252
/**
53-
* @brief View object to the vector of track candidates
53+
* @brief View object to the vector of pruned track candidates
5454
*/
55-
track_candidate_container_types::view track_candidates_view;
56-
57-
/**
58-
* @brief View object to the vector of indices meeting the selection
59-
* criteria
60-
*/
61-
vecmem::data::vector_view<unsigned int> valid_indices_view;
62-
63-
/**
64-
* @brief The number of valid tracks meeting criteria
65-
*/
66-
unsigned int* n_valid_tracks;
55+
track_candidate_container_types::view final_candidates_view;
6756
};
6857

6958
/// Function for building full tracks from the link container:

device/common/include/traccc/finding/device/impl/build_tracks.ipp

+28-73
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#pragma once
99

10+
#include <vecmem/containers/device_vector.hpp>
1011
namespace traccc::device {
1112

1213
TRACCC_DEVICE inline void build_tracks(const global_index_t globalIndex,
@@ -25,59 +26,32 @@ TRACCC_DEVICE inline void build_tracks(const global_index_t globalIndex,
2526
const vecmem::device_vector<const typename candidate_link::link_index_type>
2627
tips(payload.tips_view);
2728

28-
track_candidate_container_types::device track_candidates(
29-
payload.track_candidates_view);
30-
31-
vecmem::device_vector<unsigned int> valid_indices(
32-
payload.valid_indices_view);
29+
track_candidate_container_types::device final_candidates(
30+
payload.final_candidates_view);
3331

3432
if (globalIndex >= tips.size()) {
3533
return;
3634
}
3735

3836
const auto tip = tips.at(globalIndex);
39-
auto& seed = track_candidates[globalIndex].header.seed_params;
40-
auto& trk_quality = track_candidates[globalIndex].header.trk_quality;
41-
auto cands_per_track = track_candidates[globalIndex].items;
4237

4338
// Get the link corresponding to tip
4439
auto L = links[tip.first][tip.second];
4540
const unsigned int n_meas = measurements.size();
4641

47-
// Count the number of skipped steps
48-
unsigned int n_skipped{0u};
49-
while (true) {
50-
if (L.meas_idx > n_meas) {
51-
n_skipped++;
52-
}
53-
54-
if (L.previous.first == 0u) {
55-
break;
56-
}
57-
58-
L = links[L.previous.first][L.previous.second];
59-
}
60-
61-
// Retrieve tip
62-
L = links[tip.first][tip.second];
63-
64-
const unsigned int n_cands = tip.first + 1 - n_skipped;
65-
66-
// Resize the candidates with the exact size
67-
cands_per_track.resize(n_cands);
42+
const unsigned int num_meas = final_candidates.at(globalIndex).items.size();
6843

69-
bool success = true;
44+
assert(num_meas >= cfg.min_track_candidates_per_track &&
45+
num_meas <= cfg.max_track_candidates_per_track);
7046

7147
// Track summary variables
7248
scalar ndf_sum = 0.f;
7349
scalar chi2_sum = 0.f;
74-
75-
[[maybe_unused]] std::size_t num_inserted = 0;
50+
unsigned int final_n_skipped = 0;
51+
unsigned int seed_idx = 0;
7652

7753
// Reversely iterate to fill the track candidates
78-
for (auto it = cands_per_track.rbegin(); it != cands_per_track.rend();
79-
it++) {
80-
54+
for (unsigned int i = num_meas - 1; i < num_meas; --i) {
8155
while (L.meas_idx >= n_meas &&
8256
L.previous.first !=
8357
std::numeric_limits<
@@ -86,64 +60,45 @@ TRACCC_DEVICE inline void build_tracks(const global_index_t globalIndex,
8660
L = links[L.previous.first][L.previous.second];
8761
}
8862

89-
// Break if the measurement is still invalid
90-
if (L.meas_idx >= measurements.size()) {
91-
success = false;
92-
break;
93-
}
63+
assert(L.meas_idx < measurements.size());
9464

95-
*it = {measurements.at(L.meas_idx)};
96-
num_inserted++;
65+
final_candidates.at(globalIndex).items.at(i) =
66+
measurements.at(L.meas_idx);
9767

9868
// Sanity check on chi2
9969
assert(L.chi2 < std::numeric_limits<traccc::scalar>::max());
10070
assert(L.chi2 >= 0.f);
10171

102-
ndf_sum += static_cast<scalar>(it->meas_dim);
72+
ndf_sum += static_cast<scalar>(measurements.at(L.meas_idx).meas_dim);
10373
chi2_sum += L.chi2;
10474

10575
// Break the loop if the iterator is at the first candidate and fill the
10676
// seed and track quality
107-
if (it == cands_per_track.rend() - 1) {
108-
seed = seeds.at(L.previous.second);
109-
trk_quality.ndf = ndf_sum - 5.f;
110-
trk_quality.chi2 = chi2_sum;
111-
trk_quality.n_holes = L.n_skipped;
77+
if (i == 0) {
78+
seed_idx = L.previous.second;
79+
final_n_skipped = L.n_skipped;
11280
} else {
11381
L = links[L.previous.first][L.previous.second];
11482
}
11583
}
11684

85+
final_candidates.at(globalIndex).header.seed_params = seeds.at(seed_idx);
86+
final_candidates.at(globalIndex).header.trk_quality.chi2 = chi2_sum;
87+
final_candidates.at(globalIndex).header.trk_quality.ndf = ndf_sum - 5.f;
88+
final_candidates.at(globalIndex).header.trk_quality.n_holes =
89+
final_n_skipped;
90+
11791
#ifndef NDEBUG
118-
if (success) {
119-
// Assert that we inserted exactly as many elements as we reserved
120-
// space for.
121-
assert(num_inserted == cands_per_track.size());
122-
123-
// Assert that we did not make any duplicate track states.
124-
for (unsigned int i = 0; i < cands_per_track.size(); ++i) {
125-
for (unsigned int j = 0; j < cands_per_track.size(); ++j) {
126-
if (i != j) {
127-
assert(cands_per_track.at(i).measurement_id !=
128-
cands_per_track.at(j).measurement_id);
129-
}
92+
// Assert that we did not make any duplicate track states.
93+
for (unsigned int i = 0; i < num_meas; ++i) {
94+
for (unsigned int j = 0; j < num_meas; ++j) {
95+
if (i != j) {
96+
// TODO: Reenable
97+
// assert(meas_indxs[i] != meas_indxs[j]);
13098
}
13199
}
132100
}
133101
#endif
134-
135-
// NOTE: We may at some point want to assert that `success` is true
136-
137-
// Criteria for valid tracks
138-
if (n_cands >= cfg.min_track_candidates_per_track &&
139-
n_cands <= cfg.max_track_candidates_per_track && success) {
140-
141-
vecmem::device_atomic_ref<unsigned int> num_valid_tracks(
142-
*payload.n_valid_tracks);
143-
144-
const unsigned int pos = num_valid_tracks.fetch_add(1);
145-
valid_indices[pos] = globalIndex;
146-
}
147102
}
148103

149104
} // namespace traccc::device

device/common/include/traccc/finding/device/impl/propagate_to_next_surface.ipp

+82-70
Original file line numberDiff line numberDiff line change
@@ -47,85 +47,97 @@ TRACCC_DEVICE inline void propagate_to_next_surface(
4747
const unsigned int s_pos = num_tracks_per_seed.fetch_add(1);
4848
vecmem::device_vector<unsigned int> params_liveness(
4949
payload.params_liveness_view);
50-
51-
if (s_pos >= cfg.max_num_branches_per_seed) {
52-
params_liveness[param_id] = 0u;
53-
return;
54-
}
55-
56-
// tips
5750
vecmem::device_vector<typename candidate_link::link_index_type> tips(
5851
payload.tips_view);
52+
vecmem::device_vector<unsigned int> tip_lengths(payload.tip_lengths_view);
5953

60-
if (links.at(param_id).n_skipped > cfg.max_num_skipping_per_cand) {
61-
params_liveness[param_id] = 0u;
62-
tips.push_back({payload.step, param_id});
63-
return;
64-
}
65-
66-
// Detector
67-
typename propagator_t::detector_type det(payload.det_data);
68-
69-
// Parameters
70-
bound_track_parameters_collection_types::device params(payload.params_view);
54+
bool create_tip = false;
7155

72-
if (params_liveness.at(param_id) == 0u) {
73-
return;
56+
if (s_pos >= cfg.max_num_branches_per_seed) {
57+
params_liveness.at(param_id) = 0u;
58+
} else if (links.at(param_id).n_skipped > cfg.max_num_skipping_per_cand) {
59+
params_liveness.at(param_id) = 0u;
60+
create_tip = true;
7461
}
7562

76-
// Input bound track parameter
77-
const bound_track_parameters<> in_par = params.at(param_id);
78-
79-
// Create propagator
80-
propagator_t propagator(cfg.propagation);
81-
82-
// Create propagator state
83-
typename propagator_t::state propagation(in_par, payload.field_data, det);
84-
propagation.set_particle(
85-
detail::correct_particle_hypothesis(cfg.ptc_hypothesis, in_par));
86-
propagation._stepping
87-
.template set_constraint<detray::step::constraint::e_accuracy>(
88-
cfg.propagation.stepping.step_constraint);
89-
90-
// Actor state
91-
// @TODO: simplify the syntax here
92-
// @NOTE: Post material interaction might be required here
93-
using actor_list_type =
94-
typename propagator_t::actor_chain_type::actor_list_type;
95-
typename detray::detail::tuple_element<0, actor_list_type>::type::state
96-
s0{};
97-
typename detray::detail::tuple_element<1, actor_list_type>::type::state
98-
s1{};
99-
typename detray::detail::tuple_element<3, actor_list_type>::type::state
100-
s3{};
101-
typename detray::detail::tuple_element<2, actor_list_type>::type::state s2{
102-
s3};
103-
typename detray::detail::tuple_element<4, actor_list_type>::type::state s4;
104-
s4.min_step_length = cfg.min_step_length_for_next_surface;
105-
s4.max_count = cfg.max_step_counts_for_next_surface;
106-
107-
// @TODO: Should be removed once detray is fixed to set the volume in the
108-
// constructor
109-
propagation._navigation.set_volume(in_par.surface_link().volume());
110-
111-
// Propagate to the next surface
112-
propagator.propagate_sync(propagation, detray::tie(s0, s1, s2, s3, s4));
113-
114-
// If a surface found, add the parameter for the next step
115-
if (s4.success) {
116-
params[param_id] = propagation._stepping.bound_params();
117-
118-
if (payload.step == cfg.max_track_candidates_per_track - 1) {
119-
tips.push_back({payload.step, param_id});
120-
params_liveness[param_id] = 0u;
63+
if (params_liveness.at(param_id) != 0u) {
64+
// Detector
65+
typename propagator_t::detector_type det(payload.det_data);
66+
67+
// Parameters
68+
bound_track_parameters_collection_types::device params(
69+
payload.params_view);
70+
71+
// Input bound track parameter
72+
const bound_track_parameters<> in_par = params.at(param_id);
73+
74+
// Create propagator
75+
propagator_t propagator(cfg.propagation);
76+
77+
// Create propagator state
78+
typename propagator_t::state propagation(in_par, payload.field_data,
79+
det);
80+
propagation.set_particle(
81+
detail::correct_particle_hypothesis(cfg.ptc_hypothesis, in_par));
82+
propagation._stepping
83+
.template set_constraint<detray::step::constraint::e_accuracy>(
84+
cfg.propagation.stepping.step_constraint);
85+
86+
// Actor state
87+
// @TODO: simplify the syntax here
88+
// @NOTE: Post material interaction might be required here
89+
using actor_list_type =
90+
typename propagator_t::actor_chain_type::actor_list_type;
91+
typename detray::detail::tuple_element<0, actor_list_type>::type::state
92+
s0{};
93+
typename detray::detail::tuple_element<1, actor_list_type>::type::state
94+
s1{};
95+
typename detray::detail::tuple_element<3, actor_list_type>::type::state
96+
s3{};
97+
typename detray::detail::tuple_element<2, actor_list_type>::type::state
98+
s2{s3};
99+
typename detray::detail::tuple_element<4, actor_list_type>::type::state
100+
s4;
101+
s4.min_step_length = cfg.min_step_length_for_next_surface;
102+
s4.max_count = cfg.max_step_counts_for_next_surface;
103+
104+
// @TODO: Should be removed once detray is fixed to set the volume in
105+
// the constructor
106+
propagation._navigation.set_volume(in_par.surface_link().volume());
107+
108+
// Propagate to the next surface
109+
propagator.propagate_sync(propagation, detray::tie(s0, s1, s2, s3, s4));
110+
111+
// If a surface found, add the parameter for the next step
112+
if (s4.success) {
113+
params[param_id] = propagation._stepping.bound_params();
114+
115+
if (payload.step == cfg.max_track_candidates_per_track - 1) {
116+
create_tip = true;
117+
params_liveness[param_id] = 0u;
118+
} else {
119+
params_liveness[param_id] = 1u;
120+
}
121121
} else {
122-
params_liveness[param_id] = 1u;
122+
params_liveness[param_id] = 0u;
123+
124+
if (payload.step >= cfg.min_track_candidates_per_track - 1) {
125+
create_tip = true;
126+
}
123127
}
124-
} else {
125-
params_liveness[param_id] = 0u;
128+
}
129+
130+
if (create_tip) {
131+
const auto& L = links.at(param_id);
132+
133+
const unsigned int num_meas = payload.step + 1 - L.n_skipped;
126134

127-
if (payload.step >= cfg.min_track_candidates_per_track - 1) {
128-
tips.push_back({payload.step, param_id});
135+
// Criteria for valid tracks
136+
if (num_meas >= cfg.min_track_candidates_per_track &&
137+
num_meas <= cfg.max_track_candidates_per_track) {
138+
const unsigned int tip_pos =
139+
tips.push_back({payload.step, param_id});
140+
tip_lengths.at(tip_pos) = num_meas;
129141
}
130142
}
131143
}

0 commit comments

Comments
 (0)