From 7c2d503fe77785b255a657d91dd1c9b69e4d343f Mon Sep 17 00:00:00 2001 From: Stephen Nicholas Swatman Date: Tue, 25 Mar 2025 16:00:28 +0100 Subject: [PATCH] Keep states ordered by quality in CKF This commit allows the `find_tracks` kernel in the CKF to keep the best measurements instead of keeping random ones, which will hopefully allow us to significantly reduce the branching factor of the CKF. The logic here works by employing a series of mutexes, allowing threads to lock a critical section and insert their own measurement, overwriting the previous ones. --- .../alpaka/src/finding/finding_algorithm.cpp | 106 ++-- .../src/finding/kernels/find_tracks.hpp | 14 +- .../traccc/finding/device/find_tracks.hpp | 30 +- .../finding/device/impl/find_tracks.ipp | 475 ++++++++++++++---- .../device/impl/propagate_to_next_surface.ipp | 12 +- device/cuda/src/finding/finding_algorithm.cu | 88 ++-- .../specializations/find_tracks_src.cuh | 11 +- device/sycl/src/finding/find_tracks.hpp | 131 +++-- 8 files changed, 619 insertions(+), 248 deletions(-) diff --git a/device/alpaka/src/finding/finding_algorithm.cpp b/device/alpaka/src/finding/finding_algorithm.cpp index ba3df56a29..0a360d2a3a 100644 --- a/device/alpaka/src/finding/finding_algorithm.cpp +++ b/device/alpaka/src/finding/finding_algorithm.cpp @@ -250,57 +250,69 @@ finding_algorithm::operator()( links_buffer = std::move(new_links_buffer); } - Idx blocksPerGrid = - (n_in_params + threadsPerBlock - 1) / threadsPerBlock; - auto workDiv = makeWorkDiv(blocksPerGrid, threadsPerBlock); + { + vecmem::data::vector_buffer tmp_links_buffer( + n_max_candidates, m_mr.main); + m_copy.setup(tmp_links_buffer)->ignore(); + bound_track_parameters_collection_types::buffer + tmp_params_buffer(n_max_candidates, m_mr.main); + m_copy.setup(tmp_params_buffer)->ignore(); - const unsigned int prev_link_idx = - step == 0 ? 0 : step_to_link_idx_map[step - 1]; - - assert(links_size == step_to_link_idx_map[step]); - - typedef device::find_tracks_payload> - PayloadType; - - auto bufHost_payload = - ::alpaka::allocBuf(devHost, 1u); - PayloadType* payload = ::alpaka::getPtrNative(bufHost_payload); - - new (payload) PayloadType{ - .det_data = det_view, - .measurements_view = measurements, - .in_params_view = vecmem::get_data(in_params_buffer), - .in_params_liveness_view = - vecmem::get_data(param_liveness_buffer), - .n_in_params = n_in_params, - .barcodes_view = vecmem::get_data(barcodes_buffer), - .upper_bounds_view = vecmem::get_data(upper_bounds_buffer), - .links_view = vecmem::get_data(links_buffer), - .prev_links_idx = prev_link_idx, - .curr_links_idx = step_to_link_idx_map[step], - .step = step, - .out_params_view = vecmem::get_data(updated_params_buffer), - .out_params_liveness_view = - vecmem::get_data(updated_liveness_buffer)}; - - auto bufAcc_payload = - ::alpaka::allocBuf(devAcc, 1u); - ::alpaka::memcpy(queue, bufAcc_payload, bufHost_payload); - ::alpaka::wait(queue); + Idx blocksPerGrid = + (n_in_params + threadsPerBlock - 1) / threadsPerBlock; + auto workDiv = makeWorkDiv(blocksPerGrid, threadsPerBlock); - ::alpaka::exec(queue, workDiv, - FindTracksKernel>{}, - m_cfg, ::alpaka::getPtrNative(bufAcc_payload)); - ::alpaka::wait(queue); + const unsigned int prev_link_idx = + step == 0 ? 0 : step_to_link_idx_map[step - 1]; - std::swap(in_params_buffer, updated_params_buffer); - std::swap(param_liveness_buffer, updated_liveness_buffer); + assert(links_size == step_to_link_idx_map[step]); - // Create a buffer for links - step_to_link_idx_map[step + 1] = m_copy.get_size(links_buffer); - n_candidates = - step_to_link_idx_map[step + 1] - step_to_link_idx_map[step]; - ::alpaka::wait(queue); + typedef device::find_tracks_payload> + PayloadType; + + auto bufHost_payload = + ::alpaka::allocBuf(devHost, 1u); + PayloadType* payload = ::alpaka::getPtrNative(bufHost_payload); + + new (payload) PayloadType{ + .det_data = det_view, + .measurements_view = measurements, + .in_params_view = vecmem::get_data(in_params_buffer), + .in_params_liveness_view = + vecmem::get_data(param_liveness_buffer), + .n_in_params = n_in_params, + .barcodes_view = vecmem::get_data(barcodes_buffer), + .upper_bounds_view = vecmem::get_data(upper_bounds_buffer), + .links_view = vecmem::get_data(links_buffer), + .prev_links_idx = prev_link_idx, + .curr_links_idx = step_to_link_idx_map[step], + .step = step, + .out_params_view = vecmem::get_data(updated_params_buffer), + .out_params_liveness_view = + vecmem::get_data(updated_liveness_buffer), + .tmp_params_view = tmp_params_buffer, + .tmp_links_view = tmp_links_buffer}; + + auto bufAcc_payload = + ::alpaka::allocBuf(devAcc, 1u); + ::alpaka::memcpy(queue, bufAcc_payload, bufHost_payload); + ::alpaka::wait(queue); + + ::alpaka::exec( + queue, workDiv, + FindTracksKernel>{}, m_cfg, + ::alpaka::getPtrNative(bufAcc_payload)); + ::alpaka::wait(queue); + + std::swap(in_params_buffer, updated_params_buffer); + std::swap(param_liveness_buffer, updated_liveness_buffer); + + // Create a buffer for links + step_to_link_idx_map[step + 1] = m_copy.get_size(links_buffer); + n_candidates = + step_to_link_idx_map[step + 1] - step_to_link_idx_map[step]; + ::alpaka::wait(queue); + } } if (n_candidates > 0) { diff --git a/device/alpaka/src/finding/kernels/find_tracks.hpp b/device/alpaka/src/finding/kernels/find_tracks.hpp index fd63f15562..4bf49ffedc 100644 --- a/device/alpaka/src/finding/kernels/find_tracks.hpp +++ b/device/alpaka/src/finding/kernels/find_tracks.hpp @@ -22,10 +22,15 @@ struct FindTracksKernel { TAcc const& acc, const finding_config& cfg, device::find_tracks_payload* payload) const { + auto& shared_num_out_params = + ::alpaka::declareSharedVar(acc); + auto& shared_out_offset = + ::alpaka::declareSharedVar(acc); auto& shared_candidates_size = ::alpaka::declareSharedVar(acc); - unsigned int* const s = ::alpaka::getDynSharedMem(acc); - unsigned int* shared_num_candidates = s; + unsigned long long int* const s = + ::alpaka::getDynSharedMem(acc); + unsigned long long int* shared_insertion_mutex = s; alpaka::barrier barrier(&acc); details::thread_id1 thread_id(acc); @@ -33,11 +38,12 @@ struct FindTracksKernel { unsigned int blockDimX = thread_id.getBlockDimX(); std::pair* shared_candidates = reinterpret_cast*>( - &shared_num_candidates[blockDimX]); + &shared_insertion_mutex[blockDimX]); device::find_tracks( thread_id, barrier, cfg, *payload, - {shared_num_candidates, shared_candidates, shared_candidates_size}); + {shared_num_out_params, shared_out_offset, shared_insertion_mutex, + shared_candidates, shared_candidates_size}); } }; diff --git a/device/common/include/traccc/finding/device/find_tracks.hpp b/device/common/include/traccc/finding/device/find_tracks.hpp index 698265b5d6..eb307a6ca9 100644 --- a/device/common/include/traccc/finding/device/find_tracks.hpp +++ b/device/common/include/traccc/finding/device/find_tracks.hpp @@ -23,6 +23,7 @@ #include // System include(s). +#include #include namespace traccc::device { @@ -98,15 +99,38 @@ struct find_tracks_payload { * @brief View object to the output track parameter liveness vector */ vecmem::data::vector_view out_params_liveness_view; + + /** + * @brief View object to the temporary track parameter vector + */ + bound_track_parameters_collection_types::view tmp_params_view; + + /** + * @brief View object to the temporary link vector + */ + vecmem::data::vector_view tmp_links_view; }; /// (Shared Event Data) Payload for the @c traccc::device::find_tracks function struct find_tracks_shared_payload { /** - * @brief Shared-memory vector with the number of measurements found per - * track + * @brief Shared-memory value indicating the final number of track + * parameters to write to permanent storage. + */ + unsigned int& shared_num_out_params; + + /** + * @brief Shared-memory value indicating the offset at which the block + * will write its parameters. + */ + unsigned int& shared_out_offset; + + /** + * @brief Shared-memory array with mutexes for the insertionof parameters. + * + * @note Length is always exactly the block size. */ - unsigned int* shared_num_candidates; + unsigned long long int* shared_insertion_mutex; /** * @brief Shared-memory vector of measurement candidats with ID and diff --git a/device/common/include/traccc/finding/device/impl/find_tracks.ipp b/device/common/include/traccc/finding/device/impl/find_tracks.ipp index 524f8dec8d..4e71a615c8 100644 --- a/device/common/include/traccc/finding/device/impl/find_tracks.ipp +++ b/device/common/include/traccc/finding/device/impl/find_tracks.ipp @@ -15,6 +15,7 @@ // compiler. This can be removed when intel/llvm#15443 makes it into a OneAPI // release. #include +#include #if defined(__INTEL_LLVM_COMPILER) && defined(SYCL_LANGUAGE_VERSION) #undef __CUDA_ARCH__ #endif @@ -32,6 +33,35 @@ namespace traccc::device { +namespace details { +/** + * @brief Encode the state of our parameter insertion mutex. + */ +TRACCC_HOST_DEVICE inline uint64_t encode_insertion_mutex(const bool locked, + const uint32_t size, + const float max) { + // Assert that the MSB of the size is zero + assert(size <= 0x7FFFFFFF); + + const uint32_t hi = size | (locked ? 0x80000000 : 0x0); + const uint32_t lo = std::bit_cast(max); + + return (static_cast(hi) << 32) | lo; +} + +/** + * @brief Decode the state of our parameter insertion mutex. + */ +TRACCC_HOST_DEVICE inline std::tuple +decode_insertion_mutex(const uint64_t val) { + const uint32_t hi = static_cast(val >> 32); + const uint32_t lo = val & 0xFFFFFFFF; + + return {static_cast(hi & 0x80000000), (hi & 0x7FFFFFFF), + std::bit_cast(lo)}; +} +} // namespace details + template TRACCC_HOST_DEVICE inline void find_tracks( @@ -39,18 +69,7 @@ TRACCC_HOST_DEVICE inline void find_tracks( const finding_config& cfg, const find_tracks_payload& payload, const find_tracks_shared_payload& shared_payload) { - /* - * Initialize the block-shared data; in particular, set the total size of - * the candidate buffer to zero, and then set the number of candidates for - * each parameter to zero. - */ - if (thread_id.getLocalThreadIdX() == 0) { - shared_payload.shared_candidates_size = 0; - } - - shared_payload.shared_num_candidates[thread_id.getLocalThreadIdX()] = 0; - - barrier.blockBarrier(); + const unsigned int in_param_id = thread_id.getGlobalThreadIdX(); /* * Initialize all of the device vectors from their vecmem views. @@ -63,16 +82,28 @@ TRACCC_HOST_DEVICE inline void find_tracks( vecmem::device_vector in_params_liveness( payload.in_params_liveness_view); vecmem::device_vector links(payload.links_view); - bound_track_parameters_collection_types::device out_params( - payload.out_params_view); - vecmem::device_vector out_params_liveness( - payload.out_params_liveness_view); + vecmem::device_vector tmp_links(payload.tmp_links_view); + bound_track_parameters_collection_types::device tmp_params( + payload.tmp_params_view); vecmem::device_vector barcodes( payload.barcodes_view); vecmem::device_vector upper_bounds( payload.upper_bounds_view); - const unsigned int in_param_id = thread_id.getGlobalThreadIdX(); + /* + * Initialize the block-shared data; in particular, set the total size of + * the candidate buffer to zero, and then set the number of candidates for + * each parameter to zero. + */ + if (thread_id.getLocalThreadIdX() == 0) { + shared_payload.shared_candidates_size = 0; + shared_payload.shared_num_out_params = 0; + } + + shared_payload.shared_insertion_mutex[thread_id.getLocalThreadIdX()] = + details::encode_insertion_mutex(false, 0, 0.f); + + barrier.blockBarrier(); /* * Step 1 of this kernel is to determine which indices belong to which @@ -168,6 +199,10 @@ TRACCC_HOST_DEVICE inline void find_tracks( barrier.blockBarrier(); + std::optional, + unsigned int, unsigned int>> + result = std::nullopt; + /* * The shared buffer is now full; each thread picks out zero or one of * the measurements and processes it. @@ -197,50 +232,246 @@ TRACCC_HOST_DEVICE inline void find_tracks( gain_matrix_updater>( trk_state, in_par); - const traccc::scalar chi2 = trk_state.filtered_chi2(); - - // The chi2 from Kalman update should be less than chi2_max - if (res == kalman_fitter_status::SUCCESS && - trk_state.filtered_chi2() < cfg.chi2_max) { - // Add measurement candidates to link - const unsigned int l_pos = links.bulk_append_implicit(1); - - assert(trk_state.filtered_chi2() >= 0.f); - - if (payload.step == 0) { - links.at(l_pos) = { - .step = payload.step, - .previous_candidate_idx = owner_global_thread_id, - .meas_idx = meas_idx, - .seed_idx = owner_global_thread_id, - .n_skipped = 0, - .chi2 = chi2}; - } else { - const unsigned int prev_link_idx = - payload.prev_links_idx + owner_global_thread_id; - - const candidate_link& prev_link = links.at(prev_link_idx); - - assert(payload.step == prev_link.step + 1); - - links.at(l_pos) = {.step = payload.step, - .previous_candidate_idx = prev_link_idx, - .meas_idx = meas_idx, - .seed_idx = prev_link.seed_idx, - .n_skipped = prev_link.n_skipped, - .chi2 = chi2}; + /* + * The $\chi^2$ value from the Kalman update should be less than + * `chi2_max`, and the fit should have succeeded. If both + * conditions are true, we emplace the state, the measurement + * index, and the thread ID into an optional value. + * + * NOTE: Using the optional value here allows us to remove the + * depth of if-statements which is important for code quality but, + * more importantly, allows us to more easily use block-wide + * synchronization primitives. + */ + if (const traccc::scalar chi2 = trk_state.filtered_chi2(); + res == kalman_fitter_status::SUCCESS && chi2 < cfg.chi2_max) { + result.emplace(std::move(trk_state), meas_idx, + owner_local_thread_id); + } + } + + /* + * Now comes the stage in which we add the parameters to the temporary + * array, in such a way that we keep the best ones. This loop has a + * barrier to ensure both thread safety and forward progress. + * + * NOTE: This has to be a loop because the software is set up such + * that only one thread can write to the array of one input + * parameter per loop cycle. Thus, the loop is here to resolve any + * contention. + */ + while (barrier.blockOr(result.has_value())) { + /* + * Threads which have no parameter stored (either because they + * never had one or because they already deposited) do not have to + * do anything. + */ + if (result.has_value()) { + /* + * First, we reconstruct some necessary information from the + * data that we stored previously. + */ + const unsigned int meas_idx = std::get<1>(*result); + const unsigned int owner_local_thread_id = std::get<2>(*result); + const unsigned int owner_global_thread_id = + owner_local_thread_id + + thread_id.getBlockDimX() * thread_id.getBlockIdX(); + const float chi2 = std::get<0>(*result).filtered_chi2(); + assert(chi2 >= 0.f); + unsigned long long int* mutex_ptr = + &shared_payload + .shared_insertion_mutex[owner_local_thread_id]; + + /* + * The current thread will attempt to get a lock on the + * output array for the input parameter ID which it is now + * holding. If it manages to do so, the `index` variable will + * be set to a value smaller than or equal to the maximum + * number of elements; otherwise, it will be set to + * `UINT_MAX`. + */ + unsigned int index = std::numeric_limits::max(); + unsigned int new_size = 0; + unsigned long long int desired = 0; + + /* + * We fetch and decode whatever the mutex state is at the + * current time. The mutex is a 64-bit integer containing the + * following: + * + * [00:31] A 32-bit IEEE 754 floating point number that equals + * the highest $\chi^2$ value among parameters + * currently stored. + * [32:62] A 31-bit unsigned integer representing the number + * of parameters currently stored. + * [63:63] A boolean that, if true, indicates that a thread is + * currently operating on the array guarded. + */ + unsigned long long int assumed = *mutex_ptr; + auto [locked, size, max] = + details::decode_insertion_mutex(assumed); + + /* + * If the array is already full _and_ our parameter has a + * higher $\chi^2$ value than any of the elements in the + * array, we can discard the current track state. + */ + if (size >= cfg.max_num_branches_per_surface && chi2 >= max) { + result.reset(); } - // Increase the number of candidates (or branches) per input - // parameter - vecmem::device_atomic_ref( - shared_payload.shared_num_candidates[owner_local_thread_id]) - .fetch_add(1u); + /* + * If we still have a track after the previous check, we will + * try to add this. We can only do this if the mutex is not + * locked. + */ + if (result.has_value() && !locked) { + new_size = size < cfg.max_num_branches_per_surface + ? size + 1 + : size; + desired = + details::encode_insertion_mutex(true, new_size, max); + + /* + * Attempt to CAS the mutex with the same value as before + * but with the lock bit switched. If this succeeds (e.g. + * the return value is as we assumed) then we have succes + * fully locked and we set the `index` variable, which + * indicates that we have the lock. + */ + if (vecmem::device_atomic_ref< + unsigned long long, + vecmem::device_address_space::local>(*mutex_ptr) + .compare_exchange_strong(assumed, desired)) { + index = size; + } + } - out_params.at(l_pos - payload.curr_links_idx) = - trk_state.filtered(); - out_params_liveness.at(l_pos - payload.curr_links_idx) = 1u; + /* + * If `index` is not `UINT32_MAX`, we are in the green to + * write to the parameter array! + */ + if (index != std::numeric_limits::max()) { + assert(result.has_value()); + assert(index <= cfg.max_num_branches_per_surface); + + /* + * We will now proceed to find the index in the temporary + * array that we will write to. There are two distinct + * cases: + * + * 1. If `index` is the maximum branching value, then the + * array is already full, and we need to replace the + * worst existing parameter. + * 2. If `index` is less than the maximum branching value, + * we can trivially insert the value at index. + */ + unsigned int l_pos = + std::numeric_limits::max(); + const unsigned int p_offset = + owner_global_thread_id * + cfg.max_num_branches_per_surface; + float new_max; + + if (index == cfg.max_num_branches_per_surface) { + /* + * Handle the case in which we need to replace a + * value; find the worst existing parameter and then + * replace it. Also keep track of what the new maximum + * $\chi^2$ value will be. + */ + float highest = 0.f; + + for (unsigned int i = 0; + i < cfg.max_num_branches_per_surface; ++i) { + float old_chi2 = tmp_links.at(p_offset + i).chi2; + + if (old_chi2 > highest) { + highest = old_chi2; + l_pos = i; + } + } + + assert(l_pos != + std::numeric_limits::max()); + + new_max = chi2; + + for (unsigned int i = 0; + i < cfg.max_num_branches_per_surface; ++i) { + float old_chi2 = tmp_links.at(p_offset + i).chi2; + + if (i != l_pos && old_chi2 > new_max) { + new_max = old_chi2; + } + + assert(old_chi2 <= + tmp_links.at(p_offset + l_pos).chi2); + } + + assert(chi2 <= new_max); + } else { + l_pos = index; + new_max = std::max(chi2, max); + } + + assert(l_pos < cfg.max_num_branches_per_surface); + + /* + * Now, simply insert the temporary link at the found + * position. Different cases for step 0 and other steps. + */ + if (payload.step == 0) { + tmp_links.at(p_offset + l_pos) = { + .step = payload.step, + .previous_candidate_idx = owner_global_thread_id, + .meas_idx = meas_idx, + .seed_idx = owner_global_thread_id, + .n_skipped = 0, + .chi2 = chi2}; + } else { + const unsigned int prev_link_idx = + payload.prev_links_idx + owner_global_thread_id; + + const candidate_link& prev_link = + links.at(prev_link_idx); + + assert(payload.step == prev_link.step + 1); + + tmp_links.at(p_offset + l_pos) = { + .step = payload.step, + .previous_candidate_idx = prev_link_idx, + .meas_idx = meas_idx, + .seed_idx = prev_link.seed_idx, + .n_skipped = prev_link.n_skipped, + .chi2 = chi2}; + } + + tmp_params.at(p_offset + l_pos) = + std::get<0>(*result).filtered(); + + /* + * Reset the temporary state storage, as this is no longer + * necessary; this implies that this thread will not try + * to insert anything in the next loop iteration. + */ + result.reset(); + + /* + * Release the lock using another atomic CAS operation. + * Because nobody should be writing to this value, it + * should always succeed! + */ + [[maybe_unused]] bool cas_result = + vecmem::device_atomic_ref< + unsigned long long, + vecmem::device_address_space::local>(*mutex_ptr) + .compare_exchange_strong( + desired, details::encode_insertion_mutex( + false, new_size, new_max)); + + assert(cas_result); + } } } @@ -267,44 +498,106 @@ TRACCC_HOST_DEVICE inline void find_tracks( } /* - * Part three of the kernel inserts holes for parameters which did not - * match any measurements. + * Part three of the kernel inserts holes from the. + * + * NOTE: A synchronization point here is unnecessary, as it is implicit in + * the condition of the while-loop above. */ - if (in_param_id < payload.n_in_params && - in_params_liveness.at(in_param_id) > 0u && - shared_payload.shared_num_candidates[thread_id.getLocalThreadIdX()] == - 0u) { - // Add measurement candidates to link - const unsigned int l_pos = links.bulk_append_implicit(1); - - if (payload.step == 0) { - links.at(l_pos) = { - .step = payload.step, - .previous_candidate_idx = in_param_id, - .meas_idx = std::numeric_limits::max(), - .seed_idx = in_param_id, - .n_skipped = 1, - .chi2 = std::numeric_limits::max()}; - } else { - const unsigned int prev_link_idx = - payload.prev_links_idx + in_param_id; - const candidate_link& prev_link = links.at(prev_link_idx); + bool in_param_is_live = in_param_id < payload.n_in_params && + in_params_liveness.at(in_param_id) > 0u; - assert(payload.step == prev_link.step + 1); + unsigned int local_out_offset = 0; + unsigned int local_num_params = 0; - links.at(l_pos) = { - .step = payload.step, - .previous_candidate_idx = prev_link_idx, - .meas_idx = std::numeric_limits::max(), - .seed_idx = prev_link.seed_idx, - .n_skipped = prev_link.n_skipped + 1, - .chi2 = std::numeric_limits::max()}; - } + /* + * Compute the offset at which this block will write, as well as the index + * at which this block will write. + */ + if (in_param_is_live) { + local_num_params = std::get<1>(details::decode_insertion_mutex( + shared_payload + .shared_insertion_mutex[thread_id.getLocalThreadIdX()])); + /* + * NOTE: We always create at least one state, because we also create + * hole states for nodes which don't find any good compatible + * measurements. + */ + local_out_offset = + vecmem::device_atomic_ref( + shared_payload.shared_num_out_params) + .fetch_add(std::max(1u, local_num_params)); + } + + barrier.blockBarrier(); - out_params.at(l_pos - payload.curr_links_idx) = - in_params.at(in_param_id); - out_params_liveness.at(l_pos - payload.curr_links_idx) = 1u; + if (thread_id.getLocalThreadIdX() == 0) { + shared_payload.shared_out_offset = + links.bulk_append_implicit(shared_payload.shared_num_out_params); + } + + barrier.blockBarrier(); + + /* + * Finally, transfer the links and parameters from temporary storage + * to the permanent storage in global memory, remembering to create hole + * states even for threads which have zero states. + */ + bound_track_parameters_collection_types::device out_params( + payload.out_params_view); + vecmem::device_vector out_params_liveness( + payload.out_params_liveness_view); + + if (in_param_is_live) { + if (local_num_params == 0) { + const unsigned int out_offset = + shared_payload.shared_out_offset + local_out_offset; + + if (payload.step == 0) { + links.at(out_offset) = { + .step = payload.step, + .previous_candidate_idx = in_param_id, + .meas_idx = std::numeric_limits::max(), + .seed_idx = in_param_id, + .n_skipped = 1, + .chi2 = std::numeric_limits::max()}; + } else { + const unsigned int prev_link_idx = + payload.prev_links_idx + in_param_id; + + const candidate_link& prev_link = links.at(prev_link_idx); + + assert(payload.step == prev_link.step + 1); + + links.at(out_offset) = { + .step = payload.step, + .previous_candidate_idx = prev_link_idx, + .meas_idx = std::numeric_limits::max(), + .seed_idx = prev_link.seed_idx, + .n_skipped = prev_link.n_skipped + 1, + .chi2 = std::numeric_limits::max()}; + } + + out_params.at(out_offset - payload.curr_links_idx) = + in_params.at(in_param_id); + out_params_liveness.at(out_offset - payload.curr_links_idx) = 1u; + } else { + for (unsigned int i = 0; i < local_num_params; ++i) { + const unsigned int in_offset = + thread_id.getGlobalThreadIdX() * + cfg.max_num_branches_per_surface + + i; + const unsigned int out_offset = + shared_payload.shared_out_offset + local_out_offset + i; + + out_params.at(out_offset - payload.curr_links_idx) = + tmp_params.at(in_offset); + out_params_liveness.at(out_offset - payload.curr_links_idx) = + 1u; + links.at(out_offset) = tmp_links.at(in_offset); + } + } } } diff --git a/device/common/include/traccc/finding/device/impl/propagate_to_next_surface.ipp b/device/common/include/traccc/finding/device/impl/propagate_to_next_surface.ipp index 54548cbe46..16d0b88ac5 100644 --- a/device/common/include/traccc/finding/device/impl/propagate_to_next_surface.ipp +++ b/device/common/include/traccc/finding/device/impl/propagate_to_next_surface.ipp @@ -27,9 +27,15 @@ TRACCC_HOST_DEVICE inline void propagate_to_next_surface( // Theta id vecmem::device_vector param_ids(payload.param_ids_view); + vecmem::device_vector params_liveness( + payload.params_liveness_view); const unsigned int param_id = param_ids.at(globalIndex); + if (params_liveness.at(param_id) == 0u) { + return; + } + // Number of tracks per seed vecmem::device_vector n_tracks_per_seed( payload.n_tracks_per_seed_view); @@ -46,8 +52,6 @@ TRACCC_HOST_DEVICE inline void propagate_to_next_surface( n_tracks_per_seed.at(orig_param_id)); const unsigned int s_pos = num_tracks_per_seed.fetch_add(1); - vecmem::device_vector params_liveness( - payload.params_liveness_view); if (s_pos >= cfg.max_num_branches_per_seed) { params_liveness[param_id] = 0u; @@ -70,10 +74,6 @@ TRACCC_HOST_DEVICE inline void propagate_to_next_surface( // Parameters bound_track_parameters_collection_types::device params(payload.params_view); - if (params_liveness.at(param_id) == 0u) { - return; - } - // Input bound track parameter const bound_track_parameters<> in_par = params.at(param_id); diff --git a/device/cuda/src/finding/finding_algorithm.cu b/device/cuda/src/finding/finding_algorithm.cu index 3255b4e9f0..0c48930d17 100644 --- a/device/cuda/src/finding/finding_algorithm.cu +++ b/device/cuda/src/finding/finding_algorithm.cu @@ -257,48 +257,60 @@ finding_algorithm::operator()( links_buffer = std::move(new_links_buffer); } - const unsigned int nThreads = m_warp_size * 2; - const unsigned int nBlocks = - (n_in_params + nThreads - 1) / nThreads; - - const unsigned int prev_link_idx = - step == 0 ? 0 : step_to_link_idx_map[step - 1]; - - assert(links_size == step_to_link_idx_map[step]); - - kernels::find_tracks> - <<), - stream>>>( - m_cfg, - device::find_tracks_payload>{ - .det_data = det_view, - .measurements_view = measurements, - .in_params_view = in_params_buffer, - .in_params_liveness_view = param_liveness_buffer, - .n_in_params = n_in_params, - .barcodes_view = barcodes_buffer, - .upper_bounds_view = upper_bounds_buffer, - .links_view = links_buffer, - .prev_links_idx = prev_link_idx, - .curr_links_idx = step_to_link_idx_map[step], - .step = step, - .out_params_view = updated_params_buffer, - .out_params_liveness_view = updated_liveness_buffer}); - TRACCC_CUDA_ERROR_CHECK(cudaGetLastError()); + { + const unsigned int nThreads = m_warp_size * 2; + const unsigned int nBlocks = + (n_in_params + nThreads - 1) / nThreads; + + const unsigned int prev_link_idx = + step == 0 ? 0 : step_to_link_idx_map[step - 1]; + + assert(links_size == step_to_link_idx_map[step]); + + vecmem::data::vector_buffer tmp_links_buffer( + n_max_candidates, m_mr.main); + m_copy.setup(tmp_links_buffer)->ignore(); + bound_track_parameters_collection_types::buffer + tmp_params_buffer(n_max_candidates, m_mr.main); + m_copy.setup(tmp_params_buffer)->ignore(); + + kernels::find_tracks> + <<), + stream>>>( + m_cfg, + device::find_tracks_payload< + std::decay_t>{ + .det_data = det_view, + .measurements_view = measurements, + .in_params_view = in_params_buffer, + .in_params_liveness_view = param_liveness_buffer, + .n_in_params = n_in_params, + .barcodes_view = barcodes_buffer, + .upper_bounds_view = upper_bounds_buffer, + .links_view = links_buffer, + .prev_links_idx = prev_link_idx, + .curr_links_idx = step_to_link_idx_map[step], + .step = step, + .out_params_view = updated_params_buffer, + .out_params_liveness_view = updated_liveness_buffer, + .tmp_params_view = tmp_params_buffer, + .tmp_links_view = tmp_links_buffer}); + TRACCC_CUDA_ERROR_CHECK(cudaGetLastError()); - std::swap(in_params_buffer, updated_params_buffer); - std::swap(param_liveness_buffer, updated_liveness_buffer); + std::swap(in_params_buffer, updated_params_buffer); + std::swap(param_liveness_buffer, updated_liveness_buffer); - m_stream.synchronize(); + m_stream.synchronize(); - step_to_link_idx_map[step + 1] = m_copy.get_size(links_buffer); - n_candidates = - step_to_link_idx_map[step + 1] - step_to_link_idx_map[step]; + step_to_link_idx_map[step + 1] = m_copy.get_size(links_buffer); + n_candidates = + step_to_link_idx_map[step + 1] - step_to_link_idx_map[step]; - m_stream.synchronize(); + m_stream.synchronize(); + } } if (n_candidates > 0) { diff --git a/device/cuda/src/finding/kernels/specializations/find_tracks_src.cuh b/device/cuda/src/finding/kernels/specializations/find_tracks_src.cuh index 60b4b83a62..c35fff7e16 100644 --- a/device/cuda/src/finding/kernels/specializations/find_tracks_src.cuh +++ b/device/cuda/src/finding/kernels/specializations/find_tracks_src.cuh @@ -22,18 +22,21 @@ namespace traccc::cuda::kernels { template __global__ void find_tracks(const finding_config cfg, device::find_tracks_payload payload) { + __shared__ unsigned int shared_num_out_params; + __shared__ unsigned int shared_out_offset; __shared__ unsigned int shared_candidates_size; - extern __shared__ unsigned int s[]; - unsigned int* shared_num_candidates = s; + extern __shared__ unsigned long long int s[]; + unsigned long long int* shared_insertion_mutex = s; std::pair* shared_candidates = reinterpret_cast*>( - &shared_num_candidates[blockDim.x]); + &shared_insertion_mutex[blockDim.x]); cuda::barrier barrier; details::thread_id1 thread_id; device::find_tracks( thread_id, barrier, cfg, payload, - {shared_num_candidates, shared_candidates, shared_candidates_size}); + {shared_num_out_params, shared_out_offset, shared_insertion_mutex, + shared_candidates, shared_candidates_size}); } } // namespace traccc::cuda::kernels diff --git a/device/sycl/src/finding/find_tracks.hpp b/device/sycl/src/finding/find_tracks.hpp index 6a1e68024f..fd5b354b22 100644 --- a/device/sycl/src/finding/find_tracks.hpp +++ b/device/sycl/src/finding/find_tracks.hpp @@ -260,63 +260,84 @@ track_candidate_container_types::buffer find_tracks( links_buffer = std::move(new_links_buffer); } - // The number of threads to use per block in the track finding. - static const unsigned int nFindTracksThreads = 64; - - // Submit the kernel to the queue. - queue - .submit([&](::sycl::handler& h) { - // Allocate shared memory for the kernel. - vecmem::sycl::local_accessor - shared_num_candidates(nFindTracksThreads, h); - vecmem::sycl::local_accessor< - std::pair> - shared_candidates(2 * nFindTracksThreads, h); - vecmem::sycl::local_accessor - shared_candidates_size(1, h); - - // Launch the kernel. - h.parallel_for>( - calculate1DimNdRange(n_in_params, nFindTracksThreads), - [config, det, measurements, - in_params = vecmem::get_data(in_params_buffer), - param_liveness = vecmem::get_data(param_liveness_buffer), - n_in_params, barcodes = vecmem::get_data(barcodes_buffer), - upper_bounds = vecmem::get_data(upper_bounds_buffer), - links_view = vecmem::get_data(links_buffer), - prev_links_idx = - step == 0 ? 0 : step_to_link_idx_map[step - 1], - curr_links_idx = step_to_link_idx_map[step], step, - updated_params = vecmem::get_data(updated_params_buffer), - updated_liveness = - vecmem::get_data(updated_liveness_buffer), - shared_candidates_size, shared_num_candidates, - shared_candidates](::sycl::nd_item<1> item) { - // SYCL wrappers used in the algorithm. - const details::barrier barrier{item}; - const details::thread_id thread_id{item}; - - // Call the device function to find tracks. - device::find_tracks< - std::decay_t>( - thread_id, barrier, config, - {det, measurements, in_params, param_liveness, - n_in_params, barcodes, upper_bounds, links_view, - prev_links_idx, curr_links_idx, step, - updated_params, updated_liveness}, - {&(shared_num_candidates[0]), - &(shared_candidates[0]), - shared_candidates_size[0]}); - }); - }) - .wait_and_throw(); + { + vecmem::data::vector_buffer tmp_links_buffer( + n_max_candidates, mr.main); + copy.setup(tmp_links_buffer)->ignore(); + bound_track_parameters_collection_types::buffer tmp_params_buffer( + n_max_candidates, mr.main); + copy.setup(tmp_params_buffer)->ignore(); + + // The number of threads to use per block in the track finding. + static const unsigned int nFindTracksThreads = 64; + + // Submit the kernel to the queue. + queue + .submit([&](::sycl::handler& h) { + // Allocate shared memory for the kernel. + vecmem::sycl::local_accessor + shared_insertion_mutex(nFindTracksThreads, h); + vecmem::sycl::local_accessor< + std::pair> + shared_candidates(2 * nFindTracksThreads, h); + vecmem::sycl::local_accessor + shared_candidates_size(1, h); + vecmem::sycl::local_accessor + shared_num_out_params(1, h); + vecmem::sycl::local_accessor + shared_out_offset(1, h); + + // Launch the kernel. + h.parallel_for>( + calculate1DimNdRange(n_in_params, nFindTracksThreads), + [config, det, measurements, + in_params = vecmem::get_data(in_params_buffer), + param_liveness = + vecmem::get_data(param_liveness_buffer), + n_in_params, + barcodes = vecmem::get_data(barcodes_buffer), + upper_bounds = vecmem::get_data(upper_bounds_buffer), + links_view = vecmem::get_data(links_buffer), + prev_links_idx = + step == 0 ? 0 : step_to_link_idx_map[step - 1], + curr_links_idx = step_to_link_idx_map[step], step, + updated_params = + vecmem::get_data(updated_params_buffer), + updated_liveness = + vecmem::get_data(updated_liveness_buffer), + tmp_params = vecmem::get_data(tmp_params_buffer), + tmp_links = vecmem::get_data(tmp_links_buffer), + shared_insertion_mutex, shared_candidates, + shared_candidates_size, shared_num_out_params, + shared_out_offset](::sycl::nd_item<1> item) { + // SYCL wrappers used in the algorithm. + const details::barrier barrier{item}; + const details::thread_id thread_id{item}; + + // Call the device function to find tracks. + device::find_tracks>( + thread_id, barrier, config, + {det, measurements, in_params, param_liveness, + n_in_params, barcodes, upper_bounds, + links_view, prev_links_idx, curr_links_idx, + step, updated_params, updated_liveness, + tmp_params, tmp_links}, + {shared_num_out_params[0], shared_out_offset[0], + &(shared_insertion_mutex[0]), + &(shared_candidates[0]), + shared_candidates_size[0]}); + }); + }) + .wait_and_throw(); - std::swap(in_params_buffer, updated_params_buffer); - std::swap(param_liveness_buffer, updated_liveness_buffer); + std::swap(in_params_buffer, updated_params_buffer); + std::swap(param_liveness_buffer, updated_liveness_buffer); - step_to_link_idx_map[step + 1] = copy.get_size(links_buffer); - n_candidates = - step_to_link_idx_map[step + 1] - step_to_link_idx_map[step]; + step_to_link_idx_map[step + 1] = copy.get_size(links_buffer); + n_candidates = + step_to_link_idx_map[step + 1] - step_to_link_idx_map[step]; + } if (n_candidates > 0) { /*****************************************************************