-
Notifications
You must be signed in to change notification settings - Fork 225
[experimental] Use kernel foundry DBSCAN get_core optimizations #3592
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2978a51
5036a47
ac227a1
49d822c
5fa381f
44bcadc
8ecfdbd
ecec4dd
a5477ec
9ec2061
61fd8b7
5a9d63a
0fd4cbc
32d81c3
6ffbcca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -66,126 +66,160 @@ | |
| /// of the updated arrays(cores and neighbours) for reading and writing | ||
| template <typename Float, bool use_weights> | ||
| struct get_core_wide_kernel { | ||
| static auto run(sycl::queue& queue, | ||
| const pr::ndview<Float, 2>& data, | ||
| const pr::ndview<Float, 2>& weights, | ||
| pr::ndview<std::int32_t, 1>& cores, | ||
| pr::ndview<Float, 1>& neighbours, | ||
| Float epsilon, | ||
| std::int64_t min_observations, | ||
| const bk::event_vector& deps) { | ||
| const std::int64_t local_row_count = data.get_dimension(0); | ||
| const std::int64_t column_count = data.get_dimension(1); | ||
|
|
||
| ONEDAL_ASSERT(local_row_count > 0); | ||
| ONEDAL_ASSERT(!use_weights || weights.get_dimension(0) == local_row_count); | ||
| ONEDAL_ASSERT(!use_weights || weights.get_dimension(1) == 1); | ||
| ONEDAL_ASSERT(cores.get_dimension(0) == local_row_count); | ||
| ONEDAL_ASSERT(neighbours.get_dimension(0) == local_row_count); | ||
|
|
||
| const Float* data_ptr = data.get_data(); | ||
| const Float* weights_ptr = weights.get_data(); | ||
| std::int32_t* cores_ptr = cores.get_mutable_data(); | ||
| Float* neighbours_ptr = neighbours.get_mutable_data(); | ||
|
|
||
| ONEDAL_ASSERT(local_row_count <= std::numeric_limits<std::uint32_t>::max()); | ||
| ONEDAL_ASSERT(column_count <= std::numeric_limits<std::uint32_t>::max()); | ||
|
|
||
| auto event = queue.submit([&](sycl::handler& cgh) { | ||
| cgh.depends_on(deps); | ||
|
|
||
| const std::int64_t wg_size = get_recommended_wg_size(queue, column_count); | ||
| const std::int64_t block_split_size = | ||
| get_recommended_check_block_size(queue, column_count, wg_size); | ||
| cgh.parallel_for( | ||
|
|
||
| cgh.parallel_for<get_core_wide_kernel<Float, use_weights>>( | ||
| bk::make_multiple_nd_range_2d({ wg_size, local_row_count }, { wg_size, 1 }), | ||
| [=](sycl::nd_item<2> item) { | ||
| auto sg = item.get_sub_group(); | ||
| const std::uint32_t sg_id = sg.get_group_id()[0]; | ||
| if (sg_id > 0) | ||
| sycl::sub_group sg = item.get_sub_group(); | ||
|
|
||
| if (sg.get_group_id()[0] != 0) | ||
| return; | ||
|
|
||
| const std::uint32_t wg_id = item.get_global_id(1); | ||
| if (wg_id >= local_row_count) | ||
| const std::uint32_t row_count = static_cast<std::uint32_t>(local_row_count); | ||
| const std::uint32_t col_count = static_cast<std::uint32_t>(column_count); | ||
|
|
||
| const std::uint32_t row_i = static_cast<std::uint32_t>(item.get_global_id(1)); | ||
| if (row_i >= row_count) | ||
| return; | ||
|
|
||
| const std::uint32_t local_id = sg.get_local_id(); | ||
| const std::uint32_t local_size = sg.get_local_range()[0]; | ||
| const std::uint32_t lane = sg.get_local_id()[0]; | ||
| const std::uint32_t sg_size = sg.get_local_range()[0]; | ||
|
|
||
| const Float min_obs_f = static_cast<Float>(min_observations); | ||
|
|
||
| std::uint32_t block_split = static_cast<std::uint32_t>(block_split_size); | ||
| block_split = block_split ? block_split : 1u; | ||
|
|
||
| const std::uint32_t base_i = row_i * col_count; | ||
| const Float* const xi = data_ptr + base_i; | ||
|
|
||
| Float count = neighbours_ptr[row_i]; | ||
|
|
||
| for (std::uint32_t j = 0; j < row_count; ++j) { | ||
| const Float* const xj = data_ptr + (j * col_count); | ||
|
|
||
| Float count = neighbours_ptr[wg_id]; | ||
| for (std::int64_t j = 0; j < local_row_count; j++) { | ||
| Float sum = Float(0); | ||
| std::int64_t count_iter = 0; | ||
| for (std::int64_t i = local_id; i < column_count; i += local_size) { | ||
| count_iter++; | ||
| Float val = | ||
| data_ptr[wg_id * column_count + i] - data_ptr[j * column_count + i]; | ||
| sum += val * val; | ||
|
|
||
| if (count_iter % block_split_size == 0 && | ||
| local_size * count_iter <= column_count) { | ||
| Float distance_check = | ||
| sycl::reduce_over_group(sg, | ||
| sum, | ||
| sycl::ext::oneapi::plus<Float>()); | ||
| if (distance_check > epsilon) { | ||
| break; | ||
|
|
||
| bool pruned = false; | ||
| std::uint32_t ticks = block_split; | ||
| std::uint32_t iter = 0; | ||
|
|
||
| for (std::uint32_t k = lane; k < col_count; k += sg_size) { | ||
| ++iter; | ||
|
|
||
| const Float v = xi[k] - xj[k]; | ||
| sum = sycl::fma(v, v, sum); | ||
|
|
||
| if (--ticks == 0) { | ||
| ticks = block_split; | ||
|
|
||
| if (sg_size * iter <= col_count) { | ||
| const Float partial = | ||
| sycl::reduce_over_group(sg, sum, sycl::plus<Float>()); | ||
| if (partial > epsilon) { | ||
| pruned = true; | ||
| break; | ||
| } | ||
| } | ||
| } | ||
| } | ||
| Float distance = | ||
| sycl::reduce_over_group(sg, sum, sycl::ext::oneapi::plus<Float>()); | ||
| if (distance <= epsilon) { | ||
|
|
||
| if (pruned) { | ||
| continue; | ||
| } | ||
|
|
||
| const Float dist = sycl::reduce_over_group(sg, sum, sycl::plus<Float>()); | ||
|
|
||
| if (dist <= epsilon) { | ||
| count += use_weights ? weights_ptr[j] : Float(1); | ||
| if (local_id == 0) { | ||
| neighbours_ptr[wg_id] = count; | ||
| } | ||
| if (count >= min_observations && !use_weights) { | ||
| if (local_id == 0) { | ||
| cores_ptr[wg_id] = Float(1); | ||
|
|
||
| if (!use_weights && count >= min_obs_f) { | ||
| if (lane == 0) { | ||
| neighbours_ptr[row_i] = count; | ||
| cores_ptr[row_i] = std::int32_t(1); | ||
| } | ||
| break; | ||
| } | ||
| } | ||
| } | ||
| if (neighbours_ptr[wg_id] >= min_observations) { | ||
| cores_ptr[wg_id] = Float(1); | ||
|
|
||
| if (lane == 0) { | ||
| neighbours_ptr[row_i] = count; | ||
|
|
||
| if (count >= min_obs_f) { | ||
| cores_ptr[row_i] = std::int32_t(1); | ||
| } | ||
| } | ||
| }); | ||
| }); | ||
| return event; | ||
| } | ||
| }; | ||
|
|
||
| /// A struct that finds the core points without subgroups | ||
| /// it is effective only on narrow cases. The column count of narrow cases < 4. | ||
| /// | ||
| /// @tparam Float Floating-point type used to perform computations | ||
| /// @tparam use_weights Bool type used to check that weights are enabled | ||
| /// | ||
| /// @param[in] queue The SYCL queue | ||
| /// @param[in] data The input data of size `row_count` x `column_count` | ||
| /// @param[in] weights The input weights of size `row_count` x `1` | ||
| /// @param[in] cores The current cores of size `row_count` x `1` | ||
| /// @param[in] neighbours The current neighbours of size `row_count` x `1` | ||
| /// it contains the counter of neighbours for each point | ||
| /// @param[in] epsilon The input parameter epsilon | ||
| /// @param[in] min_observations The input parameter min_observation | ||
| /// @param[in] deps Events indicating availability of the `data` for reading or writing | ||
| /// | ||
| /// @return A SYCL event indicating the availability | ||
| /// of the updated arrays(cores and neighbours) for reading and writing | ||
| template <typename Float, bool use_weights> | ||
| struct get_core_narrow_kernel { | ||
| static auto run(sycl::queue& queue, | ||
| const pr::ndview<Float, 2>& data, | ||
| const pr::ndview<Float, 2>& weights, | ||
| pr::ndview<std::int32_t, 1>& cores, | ||
| pr::ndview<Float, 1>& neighbours, | ||
| Float epsilon, | ||
| std::int64_t min_observations, | ||
| const bk::event_vector& deps) { | ||
| const std::int64_t local_row_count = data.get_dimension(0); | ||
| const std::int64_t column_count = data.get_dimension(1); | ||
|
|
||
| ONEDAL_ASSERT(local_row_count > 0); | ||
| ONEDAL_ASSERT(!use_weights || weights.get_dimension(0) == local_row_count); | ||
| ONEDAL_ASSERT(!use_weights || weights.get_dimension(1) == 1); | ||
|
|
||
| ONEDAL_ASSERT(cores.get_dimension(0) == local_row_count); | ||
|
|
@@ -236,105 +270,140 @@ | |
| /// @param[in] queue The SYCL queue | ||
| /// @param[in] data The input data of size `row_count` x `column_count` | ||
| /// @param[in] data_replace The input data from another rank of size `row_count` x `column_count` | ||
| /// @param[in] weights The input weights of size `row_count` x `1` | ||
| /// @param[in] cores The current cores of size `row_count` x `1` | ||
| /// @param[in] neighbours The current neighbours of size `row_count` x `1` | ||
| /// it contains the counter of neighbours for each point | ||
| /// @param[in] epsilon The input parameter epsilon | ||
| /// @param[in] min_observations The input parameter min_observation | ||
| /// @param[in] deps Events indicating availability of the `data` for reading or writing | ||
| /// | ||
| /// @return A SYCL event indicating the availability | ||
| /// of the updated arrays(cores and neighbours) for reading and writing | ||
| template <typename Float, bool use_weights> | ||
| struct get_core_send_recv_replace_wide_kernel { | ||
| static auto run(sycl::queue& queue, | ||
| const pr::ndview<Float, 2>& data, | ||
| const pr::ndview<Float, 2>& data_replace, | ||
| const pr::ndview<Float, 2>& weights, | ||
| pr::ndview<std::int32_t, 1>& cores, | ||
| pr::ndview<Float, 1>& neighbours, | ||
| Float epsilon, | ||
| std::int64_t min_observations, | ||
| const bk::event_vector& deps) { | ||
| const std::int64_t local_row_count = data.get_dimension(0); | ||
| const std::int64_t row_count_replace = data_replace.get_dimension(0); | ||
| const std::int64_t column_count = data.get_dimension(1); | ||
|
|
||
| ONEDAL_ASSERT(local_row_count > 0); | ||
| ONEDAL_ASSERT(row_count_replace > 0); | ||
| ONEDAL_ASSERT(!use_weights || weights.get_dimension(0) == local_row_count); | ||
| ONEDAL_ASSERT(!use_weights || weights.get_dimension(1) == 1); | ||
|
|
||
| ONEDAL_ASSERT(cores.get_dimension(0) == local_row_count); | ||
|
|
||
| const Float* data_ptr = data.get_data(); | ||
| const Float* data_replace_ptr = data_replace.get_data(); | ||
| const Float* weights_ptr = weights.get_data(); | ||
| std::int32_t* cores_ptr = cores.get_mutable_data(); | ||
| Float* neighbours_ptr = neighbours.get_mutable_data(); | ||
|
|
||
| auto event = queue.submit([&](sycl::handler& cgh) { | ||
| cgh.depends_on(deps); | ||
| const std::int64_t wg_size = get_recommended_wg_size(queue, column_count); | ||
| const std::int64_t block_split_size = | ||
| get_recommended_check_block_size(queue, column_count, wg_size); | ||
| cgh.parallel_for( | ||
| bk::make_multiple_nd_range_2d({ wg_size, local_row_count }, { wg_size, 1 }), | ||
| [=](sycl::nd_item<2> item) { | ||
| auto sg = item.get_sub_group(); | ||
| const std::uint32_t sg_id = sg.get_group_id()[0]; | ||
| if (sg_id > 0) | ||
| sycl::sub_group sg = item.get_sub_group(); | ||
|
|
||
| if (sg.get_group_id()[0] != 0) | ||
| return; | ||
| const std::uint32_t wg_id = item.get_global_id(1); | ||
| if (wg_id >= local_row_count) | ||
|
|
||
| const std::uint32_t row_count_local = | ||
| static_cast<std::uint32_t>(local_row_count); | ||
| const std::uint32_t row_count_repl = | ||
| static_cast<std::uint32_t>(row_count_replace); | ||
| const std::uint32_t col_count = static_cast<std::uint32_t>(column_count); | ||
|
Comment on lines
+324
to
+328
|
||
|
|
||
| const std::uint32_t row_i = static_cast<std::uint32_t>(item.get_global_id(1)); | ||
| if (row_i >= row_count_local) | ||
| return; | ||
| const std::uint32_t local_id = sg.get_local_id(); | ||
| const std::uint32_t local_size = sg.get_local_range()[0]; | ||
|
|
||
| Float count = neighbours_ptr[wg_id]; | ||
| for (std::int64_t j = 0; j < row_count_replace; j++) { | ||
| const std::uint32_t lane = sg.get_local_id()[0]; | ||
| const std::uint32_t sg_size = sg.get_local_range()[0]; | ||
|
|
||
| const Float min_obs_f = static_cast<Float>(min_observations); | ||
|
|
||
| std::uint32_t block_split = static_cast<std::uint32_t>(block_split_size); | ||
| block_split = block_split ? block_split : 1u; | ||
|
|
||
| const std::uint32_t base_i = row_i * col_count; | ||
| const Float* const xi = data_ptr + base_i; | ||
|
|
||
| Float count = neighbours_ptr[row_i]; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see this is a |
||
|
|
||
| for (std::uint32_t j = 0; j < row_count_repl; ++j) { | ||
| const Float* const xj = data_replace_ptr + (j * col_count); | ||
|
|
||
|
Comment on lines
+342
to
+349
|
||
| Float sum = Float(0); | ||
| std::int64_t count_iter = 0; | ||
| for (std::int64_t i = local_id; i < column_count; i += local_size) { | ||
| count_iter++; | ||
| Float val = data_ptr[wg_id * column_count + i] - | ||
| data_replace_ptr[j * column_count + i]; | ||
| sum += val * val; | ||
| if (count_iter % block_split_size == 0 && | ||
| local_size * count_iter <= column_count) { | ||
| Float distance_check = | ||
| sycl::reduce_over_group(sg, | ||
| sum, | ||
| sycl::ext::oneapi::plus<Float>()); | ||
| if (distance_check > epsilon) { | ||
| break; | ||
|
|
||
| bool pruned = false; | ||
| std::uint32_t ticks = block_split; | ||
| std::uint32_t iter = 0; | ||
|
|
||
| for (std::uint32_t k = lane; k < col_count; k += sg_size) { | ||
| ++iter; | ||
|
|
||
| const Float v = xi[k] - xj[k]; | ||
| sum = sycl::fma(v, v, sum); | ||
|
|
||
| if (--ticks == 0) { | ||
| ticks = block_split; | ||
|
|
||
| if (sg_size * iter <= col_count) { | ||
| const Float partial = | ||
| sycl::reduce_over_group(sg, sum, sycl::plus<Float>()); | ||
| if (partial > epsilon) { | ||
| pruned = true; | ||
| break; | ||
| } | ||
| } | ||
| } | ||
| } | ||
| Float distance = | ||
| sycl::reduce_over_group(sg, sum, sycl::ext::oneapi::plus<Float>()); | ||
| if (distance <= epsilon) { | ||
|
|
||
| if (pruned) { | ||
| continue; | ||
| } | ||
|
|
||
| const Float dist = sycl::reduce_over_group(sg, sum, sycl::plus<Float>()); | ||
|
|
||
| if (dist <= epsilon) { | ||
| count += use_weights ? weights_ptr[j] : Float(1); | ||
| if (local_id == 0) { | ||
| neighbours_ptr[wg_id] = count; | ||
| } | ||
| if (count >= min_observations && !use_weights) { | ||
| if (local_id == 0) { | ||
| cores_ptr[wg_id] = Float(1); | ||
|
|
||
| if (!use_weights && count >= min_obs_f) { | ||
| if (lane == 0) { | ||
| neighbours_ptr[row_i] = count; | ||
| cores_ptr[row_i] = std::int32_t(1); | ||
| } | ||
| break; | ||
| } | ||
| } | ||
| } | ||
| if (neighbours_ptr[wg_id] >= min_observations) { | ||
| cores_ptr[wg_id] = Float(1); | ||
|
|
||
| if (lane == 0) { | ||
| neighbours_ptr[row_i] = count; | ||
|
|
||
| if (count >= min_obs_f) { | ||
| cores_ptr[row_i] = std::int32_t(1); | ||
| } | ||
| } | ||
| }); | ||
| }); | ||
| return event; | ||
| } | ||
| }; | ||
|
|
||
| /// A struct that finds the core points without subgroups | ||
| /// on sendrecv_replaced data. It means that this function tries to | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
base_i/xjpointer offsets are computed using 32-bit multiplication (row_i * col_countandj * col_count). Even thoughrow_count/col_countare individually asserted to fit inuint32_t, their product can still overflowuint32_t, leading to incorrect pointer arithmetic and potential out-of-bounds reads on the device. Consider usingstd::uint64_t/std::size_tfor offsets (or add an assert thatrow_count * col_countfits) before doing pointer arithmetic.