Skip to content

Commit f5506f0

Browse files
committed
Synchronized the SYCL CKF algorithm with the CUDA one.
1 parent 0a17b53 commit f5506f0

File tree

1 file changed

+198
-15
lines changed

1 file changed

+198
-15
lines changed

device/sycl/src/finding/combinatorial_kalman_filter.hpp

Lines changed: 198 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/** TRACCC library, part of the ACTS project (R&D line)
22
*
3-
* (c) 2024-2025 CERN for the benefit of the ACTS project
3+
* (c) 2024-2026 CERN for the benefit of the ACTS project
44
*
55
* Mozilla Public License Version 2.0
66
*/
@@ -29,8 +29,11 @@
2929
#include "traccc/finding/device/fill_finding_duplicate_removal_sort_keys.hpp"
3030
#include "traccc/finding/device/fill_finding_propagation_sort_keys.hpp"
3131
#include "traccc/finding/device/find_tracks.hpp"
32+
#include "traccc/finding/device/gather_best_tips_per_measurement.hpp"
33+
#include "traccc/finding/device/gather_measurement_votes.hpp"
3234
#include "traccc/finding/device/propagate_to_next_surface.hpp"
3335
#include "traccc/finding/device/remove_duplicates.hpp"
36+
#include "traccc/finding/device/update_tip_length_buffer.hpp"
3437
#include "traccc/finding/finding_config.hpp"
3538
#include "traccc/utils/memory_resource.hpp"
3639
#include "traccc/utils/projections.hpp"
@@ -58,6 +61,14 @@ struct fill_finding_propagation_sort_keys {};
5861
template <typename T>
5962
struct propagate_to_next_surface {};
6063
template <typename T>
64+
struct gather_best_tips_per_measurement {};
65+
66+
template <typename T>
67+
struct gather_measurement_votes {};
68+
69+
template <typename T>
70+
struct update_tip_length_buffer {};
71+
template <typename T>
6172
struct build_tracks {};
6273
} // namespace kernels
6374

@@ -184,7 +195,7 @@ combinatorial_kalman_filter(
184195
* finding, we need some space to store the intermediate Jacobians
185196
* and parameters. Allocate that space here.
186197
*/
187-
if (false && config.run_mbf_smoother) {
198+
if (config.run_mbf_smoother) {
188199
jacobian_ptr = vecmem::make_unique_alloc<
189200
bound_matrix<typename detector_t::algebra_type>[]>(
190201
mr.main, link_buffer_capacity);
@@ -273,6 +284,41 @@ combinatorial_kalman_filter(
273284
copy(links_buffer, new_links_buffer)->wait();
274285

275286
links_buffer = std::move(new_links_buffer);
287+
288+
if (config.run_mbf_smoother) {
289+
vecmem::unique_alloc_ptr<
290+
bound_matrix<typename detector_t::algebra_type>[]>
291+
new_jacobian_ptr = vecmem::make_unique_alloc<
292+
bound_matrix<typename detector_t::algebra_type>[]>(
293+
mr.main, link_buffer_capacity);
294+
bound_track_parameters_collection_types::buffer
295+
new_link_predicted_parameter_buffer{link_buffer_capacity,
296+
mr.main};
297+
bound_track_parameters_collection_types::buffer
298+
new_link_filtered_parameter_buffer{link_buffer_capacity,
299+
mr.main};
300+
301+
copy(
302+
vecmem::data::vector_view<
303+
bound_matrix<typename detector_t::algebra_type>>{
304+
links_size, jacobian_ptr.get()},
305+
vecmem::data::vector_view<
306+
bound_matrix<typename detector_t::algebra_type>>{
307+
link_buffer_capacity, new_jacobian_ptr.get()})
308+
->wait();
309+
copy(link_predicted_parameter_buffer,
310+
new_link_predicted_parameter_buffer)
311+
->wait();
312+
copy(link_filtered_parameter_buffer,
313+
new_link_filtered_parameter_buffer)
314+
->wait();
315+
316+
jacobian_ptr = std::move(new_jacobian_ptr);
317+
link_predicted_parameter_buffer =
318+
std::move(new_link_predicted_parameter_buffer);
319+
link_filtered_parameter_buffer =
320+
std::move(new_link_filtered_parameter_buffer);
321+
}
276322
}
277323

278324
{
@@ -503,7 +549,7 @@ combinatorial_kalman_filter(
503549
*****************************************************************/
504550

505551
{
506-
if (false && config.run_mbf_smoother) {
552+
if (config.run_mbf_smoother) {
507553
tmp_jacobian_ptr = vecmem::make_unique_alloc<
508554
bound_matrix<typename detector_t::algebra_type>[]>(
509555
mr.main, n_candidates);
@@ -565,7 +611,136 @@ combinatorial_kalman_filter(
565611
// Get the number of tips
566612
const unsigned int n_tips_total = copy.get_size(tips_buffer);
567613

568-
std::vector<unsigned int> tips_length_host;
614+
vecmem::vector<unsigned int> tips_length_host(mr.host);
615+
vecmem::data::vector_buffer<unsigned int> tip_to_output_map;
616+
617+
unsigned int n_tips_total_filtered = n_tips_total;
618+
619+
if (n_tips_total > 0 && config.max_num_tracks_per_measurement > 0) {
620+
// TODO: DOCS
621+
622+
vecmem::data::vector_buffer<unsigned int>
623+
best_tips_per_measurement_index_buffer(
624+
config.max_num_tracks_per_measurement * n_measurements,
625+
mr.main);
626+
copy.setup(best_tips_per_measurement_index_buffer)->wait();
627+
628+
vecmem::data::vector_buffer<unsigned long long int>
629+
best_tips_per_measurement_insertion_mutex_buffer(n_measurements,
630+
mr.main);
631+
copy.setup(best_tips_per_measurement_insertion_mutex_buffer)->wait();
632+
633+
// NOTE: This memset assumes that an all-zero bit vector interpreted
634+
// as a floating point value has value zero, which is true for IEEE
635+
// 754 but might not be true for arbitrary float formats.
636+
copy.memset(best_tips_per_measurement_insertion_mutex_buffer, 0)
637+
->wait();
638+
639+
{
640+
vecmem::data::vector_buffer<scalar>
641+
best_tips_per_measurement_pval_buffer(
642+
config.max_num_tracks_per_measurement * n_measurements,
643+
mr.main);
644+
copy.setup(best_tips_per_measurement_pval_buffer)->wait();
645+
646+
const device::gather_best_tips_per_measurement_payload<
647+
typename detector_t::algebra_type>
648+
payload{tips_buffer,
649+
links_buffer,
650+
measurements_view,
651+
best_tips_per_measurement_insertion_mutex_buffer,
652+
best_tips_per_measurement_index_buffer,
653+
best_tips_per_measurement_pval_buffer,
654+
config.max_num_tracks_per_measurement};
655+
queue
656+
.submit([&](::sycl::handler& h) {
657+
h.parallel_for<
658+
kernels::gather_best_tips_per_measurement<kernel_t>>(
659+
calculate1DimNdRange(n_tips_total, 32),
660+
[payload](::sycl::nd_item<1> item) {
661+
device::gather_best_tips_per_measurement(
662+
details::global_index(item),
663+
details::barrier{item}, payload);
664+
});
665+
})
666+
.wait_and_throw();
667+
}
668+
669+
vecmem::data::vector_buffer<unsigned int> votes_per_tip_buffer(
670+
n_tips_total, mr.main);
671+
copy.setup(votes_per_tip_buffer)->wait();
672+
copy.memset(votes_per_tip_buffer, 0)->wait();
673+
674+
{
675+
const device::gather_measurement_votes_payload payload{
676+
best_tips_per_measurement_insertion_mutex_buffer,
677+
best_tips_per_measurement_index_buffer, votes_per_tip_buffer,
678+
config.max_num_tracks_per_measurement};
679+
680+
queue
681+
.submit([&](::sycl::handler& h) {
682+
h.parallel_for<kernels::gather_measurement_votes<kernel_t>>(
683+
calculate1DimNdRange(
684+
config.max_num_tracks_per_measurement *
685+
n_measurements,
686+
512),
687+
[payload](::sycl::nd_item<1> item) {
688+
device::gather_measurement_votes(
689+
details::global_index(item), payload);
690+
});
691+
})
692+
.wait_and_throw();
693+
}
694+
695+
tip_to_output_map =
696+
vecmem::data::vector_buffer<unsigned int>(n_tips_total, mr.main);
697+
copy.setup(tip_to_output_map)->wait();
698+
699+
{
700+
vecmem::data::vector_buffer<unsigned int> new_tip_length_buffer{
701+
n_tips_total, mr.main, vecmem::data::buffer_type::resizable};
702+
copy.setup(new_tip_length_buffer)->wait();
703+
704+
const device::update_tip_length_buffer_payload payload{
705+
tip_length_buffer, new_tip_length_buffer, votes_per_tip_buffer,
706+
tip_to_output_map, config.min_measurement_voting_fraction};
707+
708+
queue
709+
.submit([&](::sycl::handler& h) {
710+
h.parallel_for<kernels::update_tip_length_buffer<kernel_t>>(
711+
calculate1DimNdRange(n_tips_total, 512),
712+
[payload](::sycl::nd_item<1> item) {
713+
device::update_tip_length_buffer(
714+
details::global_index(item), payload);
715+
});
716+
})
717+
.wait_and_throw();
718+
719+
if (mr.host) {
720+
vecmem::async_size size =
721+
copy.get_size(tip_to_output_map, *(mr.host));
722+
// Here we could give control back to the caller, once our code
723+
// allows for it. (coroutines...)
724+
n_tips_total_filtered = size.get();
725+
} else {
726+
n_tips_total_filtered = copy.get_size(tip_to_output_map);
727+
}
728+
729+
tip_length_buffer = std::move(new_tip_length_buffer);
730+
}
731+
}
732+
733+
copy(tip_length_buffer, tips_length_host)->wait();
734+
tips_length_host.resize(n_tips_total_filtered);
735+
736+
unsigned int n_states;
737+
738+
if (config.run_mbf_smoother) {
739+
n_states = std::accumulate(tips_length_host.begin(),
740+
tips_length_host.end(), 0u);
741+
} else {
742+
n_states = 0;
743+
}
569744

570745
if (n_tips_total > 0) {
571746
copy(tip_length_buffer, tips_length_host)->wait();
@@ -575,8 +750,11 @@ combinatorial_kalman_filter(
575750
// Create track candidate buffer
576751
typename edm::track_container<typename detector_t::algebra_type>::buffer
577752
track_candidates_buffer{
578-
{tips_length_host, mr.main, mr.host}, {}, measurements_view};
753+
{tips_length_host, mr.main, mr.host},
754+
{n_states, mr.main, vecmem::data::buffer_type::resizable},
755+
measurements_view};
579756
copy.setup(track_candidates_buffer.tracks)->wait();
757+
copy.setup(track_candidates_buffer.states)->wait();
580758

581759
if (n_tips_total > 0) {
582760
queue
@@ -588,21 +766,26 @@ combinatorial_kalman_filter(
588766
tracks = typename edm::track_container<
589767
typename detector_t::algebra_type>::
590768
view(track_candidates_buffer),
769+
tip_to_output_map = vecmem::get_data(tip_to_output_map),
770+
jacobian_ptr = jacobian_ptr.get(),
591771
link_predicted_parameters =
592772
vecmem::get_data(link_predicted_parameter_buffer),
593773
link_filtered_parameters =
594774
vecmem::get_data(link_filtered_parameter_buffer)](
595775
::sycl::nd_item<1> item) {
596-
device::build_tracks(details::global_index(item),
597-
false && config.run_mbf_smoother,
598-
{.seeds_view = seeds,
599-
.links_view = links,
600-
.tips_view = tips,
601-
.tracks_view = tracks,
602-
.link_predicted_parameter_view =
603-
link_predicted_parameters,
604-
.link_filtered_parameter_view =
605-
link_filtered_parameters});
776+
device::build_tracks(
777+
details::global_index(item),
778+
config.run_mbf_smoother,
779+
{.seeds_view = seeds,
780+
.links_view = links,
781+
.tips_view = tips,
782+
.tracks_view = tracks,
783+
.tip_to_output_map = tip_to_output_map,
784+
.jacobian_ptr = jacobian_ptr,
785+
.link_predicted_parameter_view =
786+
link_predicted_parameters,
787+
.link_filtered_parameter_view =
788+
link_filtered_parameters});
606789
});
607790
})
608791
.wait_and_throw();

0 commit comments

Comments
 (0)