Skip to content

Commit 369caa9

Browse files
tsulaiavVakho Tsulaia
and
Vakho Tsulaia
authored
New CPU example to demonstrate the context-aware transform store usage (#937)
New file: `examples/run/cpu/misaligned_truth_fitting_example.cpp`. Also applied several patches to allow passing the correct geometry context to detray Co-authored-by: Vakho Tsulaia <[email protected]>
1 parent 5779b47 commit 369caa9

File tree

5 files changed

+233
-11
lines changed

5 files changed

+233
-11
lines changed

core/include/traccc/fitting/kalman_filter/kalman_fitter.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ class kalman_fitter {
208208

209209
// Create propagator state
210210
typename forward_propagator_type::state propagation(
211-
seed_params, m_field, m_detector);
211+
seed_params, m_field, m_detector, m_cfg.propagation.context);
212212
propagation.set_particle(detail::correct_particle_hypothesis(
213213
m_cfg.ptc_hypothesis, seed_params));
214214

@@ -295,7 +295,7 @@ class kalman_fitter {
295295

296296
typename backward_propagator_type::state propagation(
297297
last.smoothed(), m_field, m_detector,
298-
fitter_state.m_sequence_buffer);
298+
fitter_state.m_sequence_buffer, backward_cfg.context);
299299
propagation.set_particle(detail::correct_particle_hypothesis(
300300
m_cfg.ptc_hypothesis, last.smoothed()));
301301

core/include/traccc/utils/seed_generator.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,16 @@ namespace traccc {
2626
template <typename detector_t>
2727
struct seed_generator {
2828
using algebra_type = typename detector_t::algebra_type;
29-
using cxt_t = typename detector_t::geometry_context;
29+
using ctx_t = typename detector_t::geometry_context;
3030

3131
/// Constructor with detector
3232
///
3333
/// @param det input detector
3434
/// @param stddevs standard deviations for parameter smearing
3535
seed_generator(const detector_t& det,
3636
const std::array<scalar, e_bound_size>& stddevs,
37-
const std::size_t sd = 0)
38-
: m_detector(det), m_stddevs(stddevs) {
37+
const std::size_t sd = 0, ctx_t ctx = {})
38+
: m_detector(det), m_stddevs(stddevs), m_ctx(ctx) {
3939
m_generator.seed(static_cast<std::mt19937::result_type>(sd));
4040
}
4141

@@ -51,8 +51,7 @@ struct seed_generator {
5151
// Get bound parameter
5252
const detray::tracking_surface sf{m_detector, surface_link};
5353

54-
const cxt_t ctx{};
55-
auto bound_vec = sf.free_to_bound_vector(ctx, free_param);
54+
auto bound_vec = sf.free_to_bound_vector(m_ctx, free_param);
5655
auto bound_cov = matrix::zero<traccc::bound_matrix<algebra_type>>();
5756

5857
bound_track_parameters<algebra_type> bound_param{surface_link,
@@ -68,7 +67,7 @@ struct seed_generator {
6867
typename interactor_type::state interactor_state;
6968
interactor_state.do_multiple_scattering = false;
7069
interactor_type{}.update(
71-
ctx, ptc_type, bound_param, interactor_state,
70+
m_ctx, ptc_type, bound_param, interactor_state,
7271
static_cast<int>(detray::navigation::direction::e_backward), sf);
7372

7473
for (std::size_t i = 0; i < e_bound_size; i++) {
@@ -94,6 +93,7 @@ struct seed_generator {
9493
const detector_t& m_detector;
9594
/// Standard deviations for parameter smearing
9695
std::array<scalar, e_bound_size> m_stddevs;
96+
ctx_t m_ctx;
9797
};
9898

9999
} // namespace traccc

examples/run/cpu/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ traccc_add_executable( truth_fitting_example "truth_fitting_example.cpp"
2020
LINK_LIBRARIES vecmem::core detray::io detray::detectors traccc::core
2121
traccc::io traccc::performance traccc::options)
2222

23+
traccc_add_executable( misaligned_truth_fitting_example "misaligned_truth_fitting_example.cpp"
24+
LINK_LIBRARIES vecmem::core detray::io detray::detectors traccc::core
25+
traccc::io traccc::performance traccc::options)
26+
2327
#
2428
# Set up the "throughput applications".
2529
#
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2023-2025 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
// Project include(s).
9+
#include "traccc/definitions/common.hpp"
10+
#include "traccc/definitions/primitives.hpp"
11+
#include "traccc/fitting/kalman_fitting_algorithm.hpp"
12+
#include "traccc/geometry/detector.hpp"
13+
#include "traccc/io/read_geometry.hpp"
14+
#include "traccc/io/utils.hpp"
15+
#include "traccc/options/detector.hpp"
16+
#include "traccc/options/input_data.hpp"
17+
#include "traccc/options/performance.hpp"
18+
#include "traccc/options/program_options.hpp"
19+
#include "traccc/options/track_fitting.hpp"
20+
#include "traccc/options/track_propagation.hpp"
21+
#include "traccc/resolution/fitting_performance_writer.hpp"
22+
#include "traccc/utils/seed_generator.hpp"
23+
24+
// Detray include(s).
25+
#include <detray/core/detector.hpp>
26+
#include <detray/detectors/bfield.hpp>
27+
#include <detray/io/frontend/detector_reader.hpp>
28+
#include <detray/navigation/navigator.hpp>
29+
#include <detray/propagator/propagator.hpp>
30+
#include <detray/propagator/rk_stepper.hpp>
31+
32+
// VecMem include(s).
33+
#include <vecmem/memory/host_memory_resource.hpp>
34+
#include <vecmem/utils/copy.hpp>
35+
36+
// System include(s).
37+
#include <cstdlib>
38+
#include <exception>
39+
#include <iomanip>
40+
#include <iostream>
41+
42+
using namespace traccc;
43+
namespace po = boost::program_options;
44+
45+
// The main routine
46+
//
47+
int main(int argc, char* argv[]) {
48+
std::unique_ptr<const traccc::Logger> ilogger = traccc::getDefaultLogger(
49+
"TracccExampleTruthFitting", traccc::Logging::Level::INFO);
50+
51+
TRACCC_LOCAL_LOGGER(std::move(ilogger));
52+
53+
// Program options.
54+
traccc::opts::detector detector_opts;
55+
traccc::opts::input_data input_opts;
56+
traccc::opts::track_propagation propagation_opts;
57+
traccc::opts::track_fitting fitting_opts;
58+
traccc::opts::performance performance_opts;
59+
traccc::opts::program_options program_opts{
60+
"Truth Track Fitting on the Host",
61+
{detector_opts, input_opts, propagation_opts, fitting_opts,
62+
performance_opts},
63+
argc,
64+
argv,
65+
logger().cloneWithSuffix("Options")};
66+
67+
/// Type declarations
68+
using host_detector_type = traccc::default_detector::host;
69+
70+
// Memory resources used by the application.
71+
vecmem::host_memory_resource host_mr;
72+
// Copy obejct
73+
vecmem::copy copy;
74+
75+
// Performance writer
76+
traccc::fitting_performance_writer fit_performance_writer(
77+
traccc::fitting_performance_writer::config{});
78+
79+
/*****************************
80+
* Build a geometry
81+
*****************************/
82+
83+
// B field value and its type
84+
// @TODO: Set B field as argument
85+
const traccc::vector3 B{0, 0, 2 * traccc::unit<traccc::scalar>::T};
86+
auto field = detray::bfield::create_const_field<traccc::scalar>(B);
87+
88+
// Read the detector
89+
detray::io::detector_reader_config reader_cfg{};
90+
reader_cfg.add_file(
91+
traccc::io::get_absolute_path(detector_opts.detector_file));
92+
if (!detector_opts.material_file.empty()) {
93+
reader_cfg.add_file(
94+
traccc::io::get_absolute_path(detector_opts.material_file));
95+
}
96+
if (!detector_opts.grid_file.empty()) {
97+
reader_cfg.add_file(
98+
traccc::io::get_absolute_path(detector_opts.grid_file));
99+
}
100+
const auto [host_det, names] =
101+
detray::io::read_detector<host_detector_type>(host_mr, reader_cfg);
102+
103+
/// Create a "misaligned" context in the transform store
104+
using xf_container = host_detector_type::transform_container;
105+
using xf_vector = xf_container::base_type;
106+
107+
const xf_container& default_xfs = host_det.transform_store();
108+
xf_vector misaligned_xfs;
109+
misaligned_xfs.reserve(default_xfs.size());
110+
for (const auto& xf : default_xfs) {
111+
misaligned_xfs.push_back(xf);
112+
}
113+
xf_container* ptr_default_xfs = const_cast<xf_container*>(&default_xfs);
114+
ptr_default_xfs->add_context(misaligned_xfs);
115+
116+
/*****************************
117+
* Do the reconstruction
118+
*****************************/
119+
120+
/// Standard deviations for seed track parameters
121+
static constexpr std::array<scalar, e_bound_size> stddevs = {
122+
0.03f * traccc::unit<scalar>::mm,
123+
0.03f * traccc::unit<scalar>::mm,
124+
0.017f,
125+
0.017f,
126+
0.001f / traccc::unit<scalar>::GeV,
127+
1.f * traccc::unit<scalar>::ns};
128+
129+
// Fitting algorithm objects
130+
// Alg0
131+
traccc::fitting_config fit_cfg0(fitting_opts);
132+
fit_cfg0.propagation = propagation_opts;
133+
fit_cfg0.propagation.context = detray::geometry_context{0};
134+
traccc::host::kalman_fitting_algorithm host_fitting0(
135+
fit_cfg0, host_mr, copy, logger().clone("FittingAlg0"));
136+
// Alg1
137+
traccc::fitting_config fit_cfg1(fitting_opts);
138+
fit_cfg1.propagation = propagation_opts;
139+
fit_cfg1.propagation.context = detray::geometry_context{1};
140+
traccc::host::kalman_fitting_algorithm host_fitting1(
141+
fit_cfg1, host_mr, copy, logger().clone("FittingAlg1"));
142+
143+
// Seed generators
144+
traccc::seed_generator<host_detector_type> sg0(
145+
host_det, stddevs, 0, fit_cfg0.propagation.context);
146+
traccc::seed_generator<host_detector_type> sg1(
147+
host_det, stddevs, 0, fit_cfg1.propagation.context);
148+
149+
// Iterate over events
150+
for (auto event = input_opts.skip;
151+
event < input_opts.events + input_opts.skip; ++event) {
152+
153+
// Truth Track Candidates
154+
traccc::event_data evt_data(input_opts.directory, event, host_mr,
155+
input_opts.use_acts_geom_source, &host_det,
156+
input_opts.format, false);
157+
158+
// For the first half of events run Alg0
159+
if ((event - input_opts.skip) / (input_opts.events / 2) == 0) {
160+
traccc::track_candidate_container_types::host
161+
truth_track_candidates =
162+
evt_data.generate_truth_candidates(sg0, host_mr);
163+
164+
// Run fitting
165+
auto track_states = host_fitting0(
166+
host_det, field, traccc::get_data(truth_track_candidates));
167+
168+
print_fitted_tracks_statistics(track_states);
169+
170+
const decltype(track_states)::size_type n_fitted_tracks =
171+
track_states.size();
172+
173+
if (performance_opts.run) {
174+
175+
for (unsigned int i = 0; i < n_fitted_tracks; i++) {
176+
const auto& trk_states_per_track = track_states.at(i).items;
177+
178+
const auto& fit_res = track_states[i].header;
179+
180+
fit_performance_writer.write(trk_states_per_track, fit_res,
181+
host_det, evt_data,
182+
fit_cfg0.propagation.context);
183+
}
184+
}
185+
} else {
186+
traccc::track_candidate_container_types::host
187+
truth_track_candidates =
188+
evt_data.generate_truth_candidates(sg1, host_mr);
189+
190+
// Run fitting
191+
auto track_states = host_fitting1(
192+
host_det, field, traccc::get_data(truth_track_candidates));
193+
194+
print_fitted_tracks_statistics(track_states);
195+
196+
const decltype(track_states)::size_type n_fitted_tracks =
197+
track_states.size();
198+
199+
if (performance_opts.run) {
200+
201+
for (unsigned int i = 0; i < n_fitted_tracks; i++) {
202+
const auto& trk_states_per_track = track_states.at(i).items;
203+
204+
const auto& fit_res = track_states[i].header;
205+
206+
fit_performance_writer.write(trk_states_per_track, fit_res,
207+
host_det, evt_data,
208+
fit_cfg1.propagation.context);
209+
}
210+
}
211+
}
212+
}
213+
214+
if (performance_opts.run) {
215+
fit_performance_writer.finalize();
216+
}
217+
218+
return EXIT_SUCCESS;
219+
}

performance/include/traccc/resolution/fitting_performance_writer.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ class fitting_performance_writer {
5656
template <typename detector_t>
5757
void write(const track_state_collection_types::host& track_states_per_track,
5858
const fitting_result<traccc::default_algebra>& fit_res,
59-
const detector_t& det, event_data& evt_data) {
59+
const detector_t& det, event_data& evt_data,
60+
const detector_t::geometry_context& ctx = {}) {
6061

6162
static_assert(std::same_as<typename detector_t::algebra_type,
6263
traccc::default_algebra>);
@@ -97,8 +98,6 @@ class fitting_performance_writer {
9798
const auto global_mom = meas_to_param_map.at(meas).second;
9899

99100
const detray::tracking_surface sf{det, meas.surface_link};
100-
using cxt_t = typename detector_t::geometry_context;
101-
const cxt_t ctx{};
102101
const auto truth_bound =
103102
sf.global_to_bound(ctx, global_pos, vector::normalize(global_mom));
104103

0 commit comments

Comments
 (0)