11/* * TRACCC library, part of the ACTS project (R&D line)
22 *
3- * (c) 2024-2025 CERN for the benefit of the ACTS project
3+ * (c) 2024-2026 CERN for the benefit of the ACTS project
44 *
55 * Mozilla Public License Version 2.0
66 */
2929#include " traccc/finding/device/fill_finding_duplicate_removal_sort_keys.hpp"
3030#include " traccc/finding/device/fill_finding_propagation_sort_keys.hpp"
3131#include " traccc/finding/device/find_tracks.hpp"
32+ #include " traccc/finding/device/gather_best_tips_per_measurement.hpp"
33+ #include " traccc/finding/device/gather_measurement_votes.hpp"
3234#include " traccc/finding/device/propagate_to_next_surface.hpp"
3335#include " traccc/finding/device/remove_duplicates.hpp"
36+ #include " traccc/finding/device/update_tip_length_buffer.hpp"
3437#include " traccc/finding/finding_config.hpp"
3538#include " traccc/utils/memory_resource.hpp"
3639#include " traccc/utils/projections.hpp"
@@ -58,6 +61,14 @@ struct fill_finding_propagation_sort_keys {};
5861template <typename T>
5962struct propagate_to_next_surface {};
6063template <typename T>
64+ struct gather_best_tips_per_measurement {};
65+
66+ template <typename T>
67+ struct gather_measurement_votes {};
68+
69+ template <typename T>
70+ struct update_tip_length_buffer {};
71+ template <typename T>
6172struct build_tracks {};
6273} // namespace kernels
6374
@@ -184,7 +195,7 @@ combinatorial_kalman_filter(
184195 * finding, we need some space to store the intermediate Jacobians
185196 * and parameters. Allocate that space here.
186197 */
187- if (false && config.run_mbf_smoother ) {
198+ if (config.run_mbf_smoother ) {
188199 jacobian_ptr = vecmem::make_unique_alloc<
189200 bound_matrix<typename detector_t ::algebra_type>[]>(
190201 mr.main , link_buffer_capacity);
@@ -273,6 +284,41 @@ combinatorial_kalman_filter(
273284 copy (links_buffer, new_links_buffer)->wait ();
274285
275286 links_buffer = std::move (new_links_buffer);
287+
288+ if (config.run_mbf_smoother ) {
289+ vecmem::unique_alloc_ptr<
290+ bound_matrix<typename detector_t ::algebra_type>[]>
291+ new_jacobian_ptr = vecmem::make_unique_alloc<
292+ bound_matrix<typename detector_t ::algebra_type>[]>(
293+ mr.main , link_buffer_capacity);
294+ bound_track_parameters_collection_types::buffer
295+ new_link_predicted_parameter_buffer{link_buffer_capacity,
296+ mr.main };
297+ bound_track_parameters_collection_types::buffer
298+ new_link_filtered_parameter_buffer{link_buffer_capacity,
299+ mr.main };
300+
301+ copy (
302+ vecmem::data::vector_view<
303+ bound_matrix<typename detector_t ::algebra_type>>{
304+ links_size, jacobian_ptr.get ()},
305+ vecmem::data::vector_view<
306+ bound_matrix<typename detector_t ::algebra_type>>{
307+ link_buffer_capacity, new_jacobian_ptr.get ()})
308+ ->wait ();
309+ copy (link_predicted_parameter_buffer,
310+ new_link_predicted_parameter_buffer)
311+ ->wait ();
312+ copy (link_filtered_parameter_buffer,
313+ new_link_filtered_parameter_buffer)
314+ ->wait ();
315+
316+ jacobian_ptr = std::move (new_jacobian_ptr);
317+ link_predicted_parameter_buffer =
318+ std::move (new_link_predicted_parameter_buffer);
319+ link_filtered_parameter_buffer =
320+ std::move (new_link_filtered_parameter_buffer);
321+ }
276322 }
277323
278324 {
@@ -503,7 +549,7 @@ combinatorial_kalman_filter(
503549 *****************************************************************/
504550
505551 {
506- if (false && config.run_mbf_smoother ) {
552+ if (config.run_mbf_smoother ) {
507553 tmp_jacobian_ptr = vecmem::make_unique_alloc<
508554 bound_matrix<typename detector_t ::algebra_type>[]>(
509555 mr.main , n_candidates);
@@ -565,7 +611,136 @@ combinatorial_kalman_filter(
565611 // Get the number of tips
566612 const unsigned int n_tips_total = copy.get_size (tips_buffer);
567613
568- std::vector<unsigned int > tips_length_host;
614+ vecmem::vector<unsigned int > tips_length_host (mr.host );
615+ vecmem::data::vector_buffer<unsigned int > tip_to_output_map;
616+
617+ unsigned int n_tips_total_filtered = n_tips_total;
618+
619+ if (n_tips_total > 0 && config.max_num_tracks_per_measurement > 0 ) {
620+ // TODO: DOCS
621+
622+ vecmem::data::vector_buffer<unsigned int >
623+ best_tips_per_measurement_index_buffer (
624+ config.max_num_tracks_per_measurement * n_measurements,
625+ mr.main );
626+ copy.setup (best_tips_per_measurement_index_buffer)->wait ();
627+
628+ vecmem::data::vector_buffer<unsigned long long int >
629+ best_tips_per_measurement_insertion_mutex_buffer (n_measurements,
630+ mr.main );
631+ copy.setup (best_tips_per_measurement_insertion_mutex_buffer)->wait ();
632+
633+ // NOTE: This memset assumes that an all-zero bit vector interpreted
634+ // as a floating point value has value zero, which is true for IEEE
635+ // 754 but might not be true for arbitrary float formats.
636+ copy.memset (best_tips_per_measurement_insertion_mutex_buffer, 0 )
637+ ->wait ();
638+
639+ {
640+ vecmem::data::vector_buffer<scalar>
641+ best_tips_per_measurement_pval_buffer (
642+ config.max_num_tracks_per_measurement * n_measurements,
643+ mr.main );
644+ copy.setup (best_tips_per_measurement_pval_buffer)->wait ();
645+
646+ const device::gather_best_tips_per_measurement_payload<
647+ typename detector_t ::algebra_type>
648+ payload{tips_buffer,
649+ links_buffer,
650+ measurements_view,
651+ best_tips_per_measurement_insertion_mutex_buffer,
652+ best_tips_per_measurement_index_buffer,
653+ best_tips_per_measurement_pval_buffer,
654+ config.max_num_tracks_per_measurement };
655+ queue
656+ .submit ([&](::sycl::handler& h) {
657+ h.parallel_for <
658+ kernels::gather_best_tips_per_measurement<kernel_t >>(
659+ calculate1DimNdRange (n_tips_total, 32 ),
660+ [payload](::sycl::nd_item<1 > item) {
661+ device::gather_best_tips_per_measurement (
662+ details::global_index (item),
663+ details::barrier{item}, payload);
664+ });
665+ })
666+ .wait_and_throw ();
667+ }
668+
669+ vecmem::data::vector_buffer<unsigned int > votes_per_tip_buffer (
670+ n_tips_total, mr.main );
671+ copy.setup (votes_per_tip_buffer)->wait ();
672+ copy.memset (votes_per_tip_buffer, 0 )->wait ();
673+
674+ {
675+ const device::gather_measurement_votes_payload payload{
676+ best_tips_per_measurement_insertion_mutex_buffer,
677+ best_tips_per_measurement_index_buffer, votes_per_tip_buffer,
678+ config.max_num_tracks_per_measurement };
679+
680+ queue
681+ .submit ([&](::sycl::handler& h) {
682+ h.parallel_for <kernels::gather_measurement_votes<kernel_t >>(
683+ calculate1DimNdRange (
684+ config.max_num_tracks_per_measurement *
685+ n_measurements,
686+ 512 ),
687+ [payload](::sycl::nd_item<1 > item) {
688+ device::gather_measurement_votes (
689+ details::global_index (item), payload);
690+ });
691+ })
692+ .wait_and_throw ();
693+ }
694+
695+ tip_to_output_map =
696+ vecmem::data::vector_buffer<unsigned int >(n_tips_total, mr.main );
697+ copy.setup (tip_to_output_map)->wait ();
698+
699+ {
700+ vecmem::data::vector_buffer<unsigned int > new_tip_length_buffer{
701+ n_tips_total, mr.main , vecmem::data::buffer_type::resizable};
702+ copy.setup (new_tip_length_buffer)->wait ();
703+
704+ const device::update_tip_length_buffer_payload payload{
705+ tip_length_buffer, new_tip_length_buffer, votes_per_tip_buffer,
706+ tip_to_output_map, config.min_measurement_voting_fraction };
707+
708+ queue
709+ .submit ([&](::sycl::handler& h) {
710+ h.parallel_for <kernels::update_tip_length_buffer<kernel_t >>(
711+ calculate1DimNdRange (n_tips_total, 512 ),
712+ [payload](::sycl::nd_item<1 > item) {
713+ device::update_tip_length_buffer (
714+ details::global_index (item), payload);
715+ });
716+ })
717+ .wait_and_throw ();
718+
719+ if (mr.host ) {
720+ vecmem::async_size size =
721+ copy.get_size (tip_to_output_map, *(mr.host ));
722+ // Here we could give control back to the caller, once our code
723+ // allows for it. (coroutines...)
724+ n_tips_total_filtered = size.get ();
725+ } else {
726+ n_tips_total_filtered = copy.get_size (tip_to_output_map);
727+ }
728+
729+ tip_length_buffer = std::move (new_tip_length_buffer);
730+ }
731+ }
732+
733+ copy (tip_length_buffer, tips_length_host)->wait ();
734+ tips_length_host.resize (n_tips_total_filtered);
735+
736+ unsigned int n_states;
737+
738+ if (config.run_mbf_smoother ) {
739+ n_states = std::accumulate (tips_length_host.begin (),
740+ tips_length_host.end (), 0u );
741+ } else {
742+ n_states = 0 ;
743+ }
569744
570745 if (n_tips_total > 0 ) {
571746 copy (tip_length_buffer, tips_length_host)->wait ();
@@ -575,8 +750,11 @@ combinatorial_kalman_filter(
575750 // Create track candidate buffer
576751 typename edm::track_container<typename detector_t ::algebra_type>::buffer
577752 track_candidates_buffer{
578- {tips_length_host, mr.main , mr.host }, {}, measurements_view};
753+ {tips_length_host, mr.main , mr.host },
754+ {n_states, mr.main , vecmem::data::buffer_type::resizable},
755+ measurements_view};
579756 copy.setup (track_candidates_buffer.tracks )->wait ();
757+ copy.setup (track_candidates_buffer.states )->wait ();
580758
581759 if (n_tips_total > 0 ) {
582760 queue
@@ -588,21 +766,26 @@ combinatorial_kalman_filter(
588766 tracks = typename edm::track_container<
589767 typename detector_t ::algebra_type>::
590768 view (track_candidates_buffer),
769+ tip_to_output_map = vecmem::get_data (tip_to_output_map),
770+ jacobian_ptr = jacobian_ptr.get (),
591771 link_predicted_parameters =
592772 vecmem::get_data (link_predicted_parameter_buffer),
593773 link_filtered_parameters =
594774 vecmem::get_data (link_filtered_parameter_buffer)](
595775 ::sycl::nd_item<1 > item) {
596- device::build_tracks (details::global_index (item),
597- false && config.run_mbf_smoother ,
598- {.seeds_view = seeds,
599- .links_view = links,
600- .tips_view = tips,
601- .tracks_view = tracks,
602- .link_predicted_parameter_view =
603- link_predicted_parameters,
604- .link_filtered_parameter_view =
605- link_filtered_parameters});
776+ device::build_tracks (
777+ details::global_index (item),
778+ config.run_mbf_smoother ,
779+ {.seeds_view = seeds,
780+ .links_view = links,
781+ .tips_view = tips,
782+ .tracks_view = tracks,
783+ .tip_to_output_map = tip_to_output_map,
784+ .jacobian_ptr = jacobian_ptr,
785+ .link_predicted_parameter_view =
786+ link_predicted_parameters,
787+ .link_filtered_parameter_view =
788+ link_filtered_parameters});
606789 });
607790 })
608791 .wait_and_throw ();
0 commit comments