diff --git a/cub/benchmarks/bench/radix_sort/keys.cu b/cub/benchmarks/bench/radix_sort/keys.cu index f3d95d1642a..4c308f55086 100644 --- a/cub/benchmarks/bench/radix_sort/keys.cu +++ b/cub/benchmarks/bench/radix_sort/keys.cu @@ -14,82 +14,42 @@ template void radix_sort_keys(nvbench::state& state, nvbench::type_list) { - using offset_t = cub::detail::choose_offset_t; - - constexpr cub::SortOrder sort_order = cub::SortOrder::Ascending; - constexpr bool is_overwrite_ok = false; - using key_t = T; - using value_t = cub::NullType; - - if constexpr (!fits_in_default_shared_memory()) + using key_t = T; + using value_t = cub::NullType; + if constexpr (!fits_in_default_shared_memory()) { return; } - constexpr int begin_bit = 0; - constexpr int end_bit = sizeof(key_t) * 8; - // Retrieve axis parameters const auto elements = static_cast(state.get_int64("Elements{io}")); const bit_entropy entropy = str_to_entropy(state.get_string("Entropy")); thrust::device_vector buffer_1 = generate(elements, entropy); - thrust::device_vector buffer_2(elements); + thrust::device_vector buffer_2(elements, thrust::no_init); - key_t* d_buffer_1 = thrust::raw_pointer_cast(buffer_1.data()); - key_t* d_buffer_2 = thrust::raw_pointer_cast(buffer_2.data()); - - cub::DoubleBuffer d_keys(d_buffer_1, d_buffer_2); - cub::DoubleBuffer d_values; + const key_t* d_buffer_1 = thrust::raw_pointer_cast(buffer_1.data()); + key_t* d_buffer_2 = thrust::raw_pointer_cast(buffer_2.data()); // Enable throughput calculations and add "Size" column to results. 
state.add_element_count(elements); state.add_global_memory_reads(elements, "Size"); state.add_global_memory_writes(elements); - // Allocate temporary storage: - std::size_t temp_size{}; - - cub::detail::radix_sort::dispatch( - nullptr, - temp_size, - d_keys, - d_values, - static_cast(elements), - begin_bit, - end_bit, - is_overwrite_ok, - 0 /* stream */ -#if !TUNE_BASE - , - cub::detail::identity_decomposer_t{}, - policy_selector{} -#endif // !TUNE_BASE - ); - - thrust::device_vector temp(temp_size, thrust::no_init); - auto* temp_storage = thrust::raw_pointer_cast(temp.data()); - + auto mr = cub::detail::device_memory_resource{}; state.exec(nvbench::exec_tag::gpu | nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) { - cub::DoubleBuffer keys = d_keys; - cub::DoubleBuffer values = d_values; - - cub::detail::radix_sort::dispatch( - temp_storage, - temp_size, - keys, - values, - static_cast(elements), - begin_bit, - end_bit, - is_overwrite_ok, - launch.get_stream() + cub::DeviceRadixSort::SortKeys( + d_buffer_1, + d_buffer_2, + static_cast(elements), + cuda::std::execution::env{ + ::cuda::stream_ref{launch.get_stream().get_stream()}, + mr, #if !TUNE_BASE , - cub::detail::identity_decomposer_t{}, - policy_selector{} + cuda::execution::__tune(policy_selector{}) #endif // !TUNE_BASE - ); + }); }); } diff --git a/cub/benchmarks/bench/radix_sort/pairs.cu b/cub/benchmarks/bench/radix_sort/pairs.cu index 74d3af8bfe3..498fbaec412 100644 --- a/cub/benchmarks/bench/radix_sort/pairs.cu +++ b/cub/benchmarks/bench/radix_sort/pairs.cu @@ -16,37 +16,26 @@ template void radix_sort_values(nvbench::state& state, nvbench::type_list) { - using offset_t = cub::detail::choose_offset_t; - - constexpr cub::SortOrder sort_order = cub::SortOrder::Ascending; - constexpr bool is_overwrite_ok = false; - using key_t = KeyT; - using value_t = ValueT; - - if constexpr (!fits_in_default_shared_memory()) + using key_t = KeyT; + using value_t = ValueT; + if constexpr 
(!fits_in_default_shared_memory()) { return; } - constexpr int begin_bit = 0; - constexpr int end_bit = sizeof(key_t) * 8; - // Retrieve axis parameters const auto elements = static_cast(state.get_int64("Elements{io}")); const bit_entropy entropy = str_to_entropy(state.get_string("Entropy")); - thrust::device_vector keys_buffer_1 = generate(elements, entropy); - thrust::device_vector values_buffer_1 = generate(elements); - thrust::device_vector keys_buffer_2(elements); - thrust::device_vector values_buffer_2(elements); + thrust::device_vector keys_in = generate(elements, entropy); + thrust::device_vector keys_out(elements, thrust::no_init); + thrust::device_vector values_in = generate(elements); + thrust::device_vector values_out(elements, thrust::no_init); - key_t* d_keys_buffer_1 = thrust::raw_pointer_cast(keys_buffer_1.data()); - key_t* d_keys_buffer_2 = thrust::raw_pointer_cast(keys_buffer_2.data()); - value_t* d_values_buffer_1 = thrust::raw_pointer_cast(values_buffer_1.data()); - value_t* d_values_buffer_2 = thrust::raw_pointer_cast(values_buffer_2.data()); - - cub::DoubleBuffer d_keys(d_keys_buffer_1, d_keys_buffer_2); - cub::DoubleBuffer d_values(d_values_buffer_1, d_values_buffer_2); + const key_t* d_keys_in = thrust::raw_pointer_cast(keys_in.data()); + key_t* d_keys_out = thrust::raw_pointer_cast(keys_out.data()); + const value_t* d_values_in = thrust::raw_pointer_cast(values_in.data()); + value_t* d_values_out = thrust::raw_pointer_cast(values_out.data()); // Enable throughput calculations and add "Size" column to results. 
state.add_element_count(elements); @@ -55,48 +44,22 @@ void radix_sort_values(nvbench::state& state, nvbench::type_list(elements); state.add_global_memory_writes(elements); - // Allocate temporary storage: - std::size_t temp_size{}; - cub::detail::radix_sort::dispatch( - nullptr, - temp_size, - d_keys, - d_values, - static_cast(elements), - begin_bit, - end_bit, - is_overwrite_ok, - 0 /* stream */ -#if !TUNE_BASE - , - cub::detail::identity_decomposer_t{}, - policy_selector{} -#endif // !TUNE_BASE - ); - - thrust::device_vector temp(temp_size, thrust::no_init); - auto* temp_storage = thrust::raw_pointer_cast(temp.data()); - + auto mr = cub::detail::device_memory_resource{}; state.exec(nvbench::exec_tag::gpu | nvbench::exec_tag::no_batch, [&](nvbench::launch& launch) { - cub::DoubleBuffer keys = d_keys; - cub::DoubleBuffer values = d_values; - - cub::detail::radix_sort::dispatch( - temp_storage, - temp_size, - keys, - values, - static_cast(elements), - begin_bit, - end_bit, - is_overwrite_ok, - launch.get_stream() + cub::DeviceRadixSort::SortPairs( + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, + static_cast(elements), + cuda::std::execution::env{ + ::cuda::stream_ref{launch.get_stream().get_stream()}, + mr, #if !TUNE_BASE , - cub::detail::identity_decomposer_t{}, - policy_selector{} + cuda::execution::__tune(policy_selector{}) #endif // !TUNE_BASE - ); + }); }); } diff --git a/cub/benchmarks/bench/radix_sort/policy_selector.h b/cub/benchmarks/bench/radix_sort/policy_selector.h index 288898ccac1..1e94b418f3c 100644 --- a/cub/benchmarks/bench/radix_sort/policy_selector.h +++ b/cub/benchmarks/bench/radix_sort/policy_selector.h @@ -76,13 +76,33 @@ struct policy_selector template constexpr std::size_t max_onesweep_temp_storage_size() { - using portion_offset = int; - using onesweep_policy = typename policy_hub_t::policy_t::OnesweepPolicy; + using portion_offset = int; + + constexpr auto active_policy = policy_selector{}(cuda::arch_id{}); + + constexpr auto 
onesweep = active_policy.onesweep; + using onesweep_policy_t = AgentRadixSortOnesweepPolicy< + 0, + 0, + void, + onesweep.rank_num_parts, + onesweep.rank_algorithm, + onesweep.scan_algorithm, + onesweep.store_algorithm, + onesweep.radix_bits, + NoScaling>; + using agent_radix_sort_onesweep_t = - cub::AgentRadixSortOnesweep; + cub::AgentRadixSortOnesweep; - using hist_policy = typename policy_hub_t::policy_t::HistogramPolicy; - using hist_agent = cub::AgentRadixSortHistogram; + constexpr auto histogram = active_policy.histogram; + using histogram_policy_t = + AgentRadixSortHistogramPolicy; + using hist_agent = cub::AgentRadixSortHistogram; return cuda::std::max(sizeof(typename agent_radix_sort_onesweep_t::TempStorage), sizeof(typename hist_agent::TempStorage)); @@ -91,10 +111,11 @@ constexpr std::size_t max_onesweep_temp_storage_size() template constexpr std::size_t max_temp_storage_size() { - using policy_t = typename policy_hub_t::policy_t; + using offset_t = cub::detail::choose_offset_t; + using policy_t = typename policy_hub_t::policy_t; static_assert(policy_t::ONESWEEP); - return max_onesweep_temp_storage_size(); + return max_onesweep_temp_storage_size(); } template diff --git a/cub/benchmarks/bench/transform/common.h b/cub/benchmarks/bench/transform/common.h index 90547fc2dd6..08cb5df557b 100644 --- a/cub/benchmarks/bench/transform/common.h +++ b/cub/benchmarks/bench/transform/common.h @@ -26,8 +26,7 @@ #include #if !TUNE_BASE -// TODO(bgruber): can we get by without the base class? 
-struct policy_selector : cub::detail::transform::tuning +struct policy_selector { _CCCL_API constexpr auto operator()(cuda::arch_id) const -> cub::detail::transform::transform_policy { diff --git a/cub/cub/device/device_radix_sort.cuh b/cub/cub/device/device_radix_sort.cuh index da2b513c49f..8f1b94ae233 100644 --- a/cub/cub/device/device_radix_sort.cuh +++ b/cub/cub/device/device_radix_sort.cuh @@ -32,24 +32,6 @@ CUB_NAMESPACE_BEGIN -namespace detail::radix_sort -{ -struct get_tuning_query_t -{}; - -template -struct tuning -{ - [[nodiscard]] _CCCL_TRIVIAL_API constexpr auto query(const get_tuning_query_t&) const noexcept -> Derived - { - return static_cast(*this); - } -}; - -struct default_tuning : tuning -{}; -} // namespace detail::radix_sort - //! @rst //! DeviceRadixSort provides device-wide, parallel operations for //! computing a radix sort across a sequence of data items residing @@ -429,8 +411,8 @@ public: template , - typename ::cuda::std::enable_if_t<::cuda::std::is_integral_v, int> = 0> + typename EnvT = ::cuda::std::execution::env<>, + ::cuda::std::enable_if_t<::cuda::std::is_integral_v, int> = 0> [[nodiscard]] CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t SortPairs( const KeyT* d_keys_in, KeyT* d_keys_out, @@ -445,14 +427,28 @@ public: using offset_t = detail::choose_offset_t; - // Dispatch with environment - handles all boilerplate DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values(const_cast(d_values_in), d_values_out); - return detail::dispatch_with_env(env, [&]([[maybe_unused]] auto tuning, void* storage, size_t& bytes, auto stream) { - return detail::radix_sort::dispatch( - storage, bytes, d_keys, d_values, static_cast(num_items), begin_bit, end_bit, false, stream); - }); + // Dispatch with environment - handles all boilerplate + return detail::dispatch_with_env( + env, [&]([[maybe_unused]] auto tuning_env, void* storage, size_t& bytes, auto stream) { + using default_policy_selector_t = 
detail::radix_sort::policy_selector_from_types; + using policy_selector_t = ::cuda::std::execution:: + __query_result_or_t; + return detail::radix_sort::dispatch( + storage, + bytes, + d_keys, + d_values, + static_cast(num_items), + begin_bit, + end_bit, + false, + stream, + {}, + policy_selector_t{}); + }); } //! @rst @@ -963,8 +959,8 @@ public: template , - typename ::cuda::std::enable_if_t<::cuda::std::is_integral_v, int> = 0> + typename EnvT = ::cuda::std::execution::env<>, + ::cuda::std::enable_if_t<::cuda::std::is_integral_v, int> = 0> [[nodiscard]] CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t SortPairs( DoubleBuffer& d_keys, DoubleBuffer& d_values, @@ -977,10 +973,24 @@ public: using offset_t = detail::choose_offset_t; - return detail::dispatch_with_env(env, [&]([[maybe_unused]] auto tuning, void* storage, size_t& bytes, auto stream) { - return detail::radix_sort::dispatch( - storage, bytes, d_keys, d_values, static_cast(num_items), begin_bit, end_bit, true, stream); - }); + return detail::dispatch_with_env( + env, [&]([[maybe_unused]] auto tuning_env, void* storage, size_t& bytes, auto stream) { + using default_policy_selector_t = detail::radix_sort::policy_selector_from_types; + using policy_selector_t = ::cuda::std::execution:: + __query_result_or_t; + return detail::radix_sort::dispatch( + storage, + bytes, + d_keys, + d_values, + static_cast(num_items), + begin_bit, + end_bit, + true, + stream, + {}, + policy_selector_t{}); + }); } //! 
@rst @@ -1494,8 +1504,8 @@ public: template , - typename ::cuda::std::enable_if_t<::cuda::std::is_integral_v, int> = 0> + typename EnvT = ::cuda::std::execution::env<>, + ::cuda::std::enable_if_t<::cuda::std::is_integral_v, int> = 0> [[nodiscard]] CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t SortPairsDescending( const KeyT* d_keys_in, KeyT* d_keys_out, @@ -1513,10 +1523,24 @@ public: DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values(const_cast(d_values_in), d_values_out); - return detail::dispatch_with_env(env, [&]([[maybe_unused]] auto tuning, void* storage, size_t& bytes, auto stream) { - return detail::radix_sort::dispatch( - storage, bytes, d_keys, d_values, static_cast(num_items), begin_bit, end_bit, false, stream); - }); + return detail::dispatch_with_env( + env, [&]([[maybe_unused]] auto tuning_env, void* storage, size_t& bytes, auto stream) { + using default_policy_selector_t = detail::radix_sort::policy_selector_from_types; + using policy_selector_t = ::cuda::std::execution:: + __query_result_or_t; + return detail::radix_sort::dispatch( + storage, + bytes, + d_keys, + d_values, + static_cast(num_items), + begin_bit, + end_bit, + false, + stream, + {}, + policy_selector_t{}); + }); } //! 
@rst @@ -2031,8 +2055,8 @@ public: template , - typename ::cuda::std::enable_if_t<::cuda::std::is_integral_v, int> = 0> + typename EnvT = ::cuda::std::execution::env<>, + ::cuda::std::enable_if_t<::cuda::std::is_integral_v, int> = 0> [[nodiscard]] CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t SortPairsDescending( DoubleBuffer& d_keys, DoubleBuffer& d_values, @@ -2045,10 +2069,24 @@ public: using offset_t = detail::choose_offset_t; - return detail::dispatch_with_env(env, [&]([[maybe_unused]] auto tuning, void* storage, size_t& bytes, auto stream) { - return detail::radix_sort::dispatch( - storage, bytes, d_keys, d_values, static_cast(num_items), begin_bit, end_bit, true, stream); - }); + return detail::dispatch_with_env( + env, [&]([[maybe_unused]] auto tuning_env, void* storage, size_t& bytes, auto stream) { + using default_policy_selector_t = detail::radix_sort::policy_selector_from_types; + using policy_selector_t = ::cuda::std::execution:: + __query_result_or_t; + return detail::radix_sort::dispatch( + storage, + bytes, + d_keys, + d_values, + static_cast(num_items), + begin_bit, + end_bit, + true, + stream, + {}, + policy_selector_t{}); + }); } //! @rst @@ -2540,8 +2578,8 @@ public: //! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``. 
template , - typename ::cuda::std::enable_if_t<::cuda::std::is_integral_v, int> = 0> + typename EnvT = ::cuda::std::execution::env<>, + ::cuda::std::enable_if_t<::cuda::std::is_integral_v, int> = 0> [[nodiscard]] CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t SortKeys( const KeyT* d_keys_in, KeyT* d_keys_out, @@ -2558,10 +2596,24 @@ public: DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values; - return detail::dispatch_with_env(env, [&]([[maybe_unused]] auto tuning, void* storage, size_t& bytes, auto stream) { - return detail::radix_sort::dispatch( - storage, bytes, d_keys, d_values, static_cast(num_items), begin_bit, end_bit, false, stream); - }); + return detail::dispatch_with_env( + env, [&]([[maybe_unused]] auto tuning_env, void* storage, size_t& bytes, auto stream) { + using default_policy_selector_t = detail::radix_sort::policy_selector_from_types; + using policy_selector_t = ::cuda::std::execution:: + __query_result_or_t; + return detail::radix_sort::dispatch( + storage, + bytes, + d_keys, + d_values, + static_cast(num_items), + begin_bit, + end_bit, + false, + stream, + {}, + policy_selector_t{}); + }); } //! @rst @@ -3036,8 +3088,8 @@ public: //! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``. 
template , - typename ::cuda::std::enable_if_t<::cuda::std::is_integral_v, int> = 0> + typename EnvT = ::cuda::std::execution::env<>, + ::cuda::std::enable_if_t<::cuda::std::is_integral_v, int> = 0> [[nodiscard]] CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t SortKeys( DoubleBuffer& d_keys, NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, EnvT env = {}) { @@ -3047,10 +3099,24 @@ public: DoubleBuffer d_values; - return detail::dispatch_with_env(env, [&]([[maybe_unused]] auto tuning, void* storage, size_t& bytes, auto stream) { - return detail::radix_sort::dispatch( - storage, bytes, d_keys, d_values, static_cast(num_items), begin_bit, end_bit, true, stream); - }); + return detail::dispatch_with_env( + env, [&]([[maybe_unused]] auto tuning_env, void* storage, size_t& bytes, auto stream) { + using default_policy_selector_t = detail::radix_sort::policy_selector_from_types; + using policy_selector_t = ::cuda::std::execution:: + __query_result_or_t; + return detail::radix_sort::dispatch( + storage, + bytes, + d_keys, + d_values, + static_cast(num_items), + begin_bit, + end_bit, + true, + stream, + {}, + policy_selector_t{}); + }); } //! @rst @@ -3512,8 +3578,8 @@ public: //! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``. 
template , - typename ::cuda::std::enable_if_t<::cuda::std::is_integral_v, int> = 0> + typename EnvT = ::cuda::std::execution::env<>, + ::cuda::std::enable_if_t<::cuda::std::is_integral_v, int> = 0> [[nodiscard]] CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t SortKeysDescending( const KeyT* d_keys_in, KeyT* d_keys_out, @@ -3530,10 +3596,24 @@ public: DoubleBuffer d_keys(const_cast(d_keys_in), d_keys_out); DoubleBuffer d_values; - return detail::dispatch_with_env(env, [&]([[maybe_unused]] auto tuning, void* storage, size_t& bytes, auto stream) { - return detail::radix_sort::dispatch( - storage, bytes, d_keys, d_values, static_cast(num_items), begin_bit, end_bit, false, stream); - }); + return detail::dispatch_with_env( + env, [&]([[maybe_unused]] auto tuning_env, void* storage, size_t& bytes, auto stream) { + using default_policy_selector_t = detail::radix_sort::policy_selector_from_types; + using policy_selector_t = ::cuda::std::execution:: + __query_result_or_t; + return detail::radix_sort::dispatch( + storage, + bytes, + d_keys, + d_values, + static_cast(num_items), + begin_bit, + end_bit, + false, + stream, + {}, + policy_selector_t{}); + }); } //! @rst @@ -4006,8 +4086,8 @@ public: //! **[optional]** Execution environment. Default is ``cuda::std::execution::env{}``. 
template , - typename ::cuda::std::enable_if_t<::cuda::std::is_integral_v, int> = 0> + typename EnvT = ::cuda::std::execution::env<>, + ::cuda::std::enable_if_t<::cuda::std::is_integral_v, int> = 0> [[nodiscard]] CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t SortKeysDescending( DoubleBuffer& d_keys, NumItemsT num_items, int begin_bit = 0, int end_bit = sizeof(KeyT) * 8, EnvT env = {}) { @@ -4017,10 +4097,24 @@ public: DoubleBuffer d_values; - return detail::dispatch_with_env(env, [&]([[maybe_unused]] auto tuning, void* storage, size_t& bytes, auto stream) { - return detail::radix_sort::dispatch( - storage, bytes, d_keys, d_values, static_cast(num_items), begin_bit, end_bit, true, stream); - }); + return detail::dispatch_with_env( + env, [&]([[maybe_unused]] auto tuning_env, void* storage, size_t& bytes, auto stream) { + using default_policy_selector_t = detail::radix_sort::policy_selector_from_types; + using policy_selector_t = ::cuda::std::execution:: + __query_result_or_t; + return detail::radix_sort::dispatch( + storage, + bytes, + d_keys, + d_values, + static_cast(num_items), + begin_bit, + end_bit, + true, + stream, + {}, + policy_selector_t{}); + }); } //! 
@rst diff --git a/cub/cub/device/device_reduce.cuh b/cub/cub/device/device_reduce.cuh index 628601149eb..6dac9dfaa9d 100644 --- a/cub/cub/device/device_reduce.cuh +++ b/cub/cub/device/device_reduce.cuh @@ -59,24 +59,6 @@ inline constexpr bool is_non_deterministic_v = namespace reduce { -struct get_tuning_query_t -{}; - -template -struct tuning -{ - [[nodiscard]] _CCCL_NODEBUG_API constexpr auto query(const get_tuning_query_t&) const noexcept -> Derived - { - return static_cast(*this); - } -}; - -struct default_rfa_tuning : tuning -{ - template - using fn = detail::rfa::policy_selector_from_types; -}; - template struct unzip_and_write_arg_extremum_op { @@ -145,10 +127,9 @@ private: using offset_t = detail::choose_offset_t; using accum_t = ::cuda::std:: __accumulator_t>, T>; - using reduce_tuning_t = ::cuda::std::execution::__query_result_or_t< - TuningEnvT, - detail::reduce::get_tuning_query_t, - detail::reduce::policy_selector_from_types>; + using default_policy_selector = detail::reduce::policy_selector_from_types; + using policy_selector = + ::cuda::std::execution::__query_result_or_t; return detail::reduce::dispatch( d_temp_storage, @@ -160,7 +141,7 @@ private: init, stream, transform_op, - reduce_tuning_t{}); + policy_selector{}); } template ; - - using reduce_tuning_t = ::cuda::std::execution:: - __query_result_or_t; - - using accum_t = ::cuda::std:: + using accum_t = ::cuda::std:: __accumulator_t>, T>; - using policy_t = typename reduce_tuning_t::template fn; + using default_policy_selector = detail::rfa::policy_selector_from_types; + using policy_selector = + ::cuda::std::execution::__query_result_or_t; - return detail::rfa::dispatch( - d_temp_storage, temp_storage_bytes, d_in, d_out, static_cast(num_items), init, stream, transform_op); + return detail::rfa::dispatch( + d_temp_storage, + temp_storage_bytes, + d_in, + d_out, + static_cast(num_items), + init, + stream, + transform_op, + policy_selector{}); } template ; - using accum_t = 
::cuda::std::__accumulator_t, T>; - - using reduce_tuning_t = ::cuda::std::execution::__query_result_or_t< - TuningEnvT, - detail::reduce::get_tuning_query_t, - detail::reduce::policy_selector_from_types>; + using offset_t = detail::choose_offset_t; + using accum_t = ::cuda::std::__accumulator_t, T>; + using default_policy_selector = detail::reduce::policy_selector_from_types; + using policy_selector = + ::cuda::std::execution::__query_result_or_t; return detail::reduce::dispatch_nondeterministic( d_temp_storage, @@ -232,7 +217,7 @@ private: init, stream, transform_op, - reduce_tuning_t{}); + policy_selector{}); } public: diff --git a/cub/cub/device/device_transform.cuh b/cub/cub/device/device_transform.cuh index 74bd0ecdfac..13051f03d1f 100644 --- a/cub/cub/device/device_transform.cuh +++ b/cub/cub/device/device_transform.cuh @@ -46,27 +46,6 @@ struct ::cuda::proclaims_copyable_arguments -// TODO(bgruber): we cannot check the concept here because PolicySelector is usually an incomplete type still -// #if _CCCL_HAS_CONCEPTS() -// requires transform_policy_selector -// #endif // _CCCL_HAS_CONCEPTS() -struct tuning -{ - [[nodiscard]] _CCCL_TRIVIAL_API constexpr auto query(const get_tuning_query_t&) const noexcept -> PolicySelector - { - return static_cast(*this); - } -}; -} // namespace detail::transform - //! DeviceTransform provides device-wide, parallel operations for transforming elements tuple-wise from multiple input //! sequences into an output sequence. 
struct DeviceTransform @@ -105,8 +84,9 @@ private: ::cuda::std::is_same_v, ::cuda::std::tuple, RandomAccessIteratorOut>; + using policy_selector = ::cuda::std::execution:: - __query_result_or_t; + __query_result_or_t; #if _CCCL_HAS_CONCEPTS() static_assert(detail::transform::transform_policy_selector); diff --git a/cub/cub/device/dispatch/dispatch_segmented_radix_sort.cuh b/cub/cub/device/dispatch/dispatch_segmented_radix_sort.cuh index 29b813a3414..b51b942189c 100644 --- a/cub/cub/device/dispatch/dispatch_segmented_radix_sort.cuh +++ b/cub/cub/device/dispatch/dispatch_segmented_radix_sort.cuh @@ -861,17 +861,17 @@ template , - typename KernelSource = DeviceSegmentedRadixSortKernelSource< - PolicySelector, - Order, - KeyT, - ValueT, - BeginOffsetIteratorT, - EndOffsetIteratorT, - SegmentSizeT, - DecomposerT>, + typename DecomposerT = identity_decomposer_t, + typename PolicySelector, + typename KernelSource = DeviceSegmentedRadixSortKernelSource< + PolicySelector, + Order, + KeyT, + ValueT, + BeginOffsetIteratorT, + EndOffsetIteratorT, + SegmentSizeT, + DecomposerT>, typename KernelLauncherFactory = CUB_DETAIL_DEFAULT_KERNEL_LAUNCHER_FACTORY> #if _CCCL_HAS_CONCEPTS() requires radix_sort_policy_selector diff --git a/cub/test/catch2_test_device_radix_sort_env.cu b/cub/test/catch2_test_device_radix_sort_env.cu index bff22a30190..f8ee35b143b 100644 --- a/cub/test/catch2_test_device_radix_sort_env.cu +++ b/cub/test/catch2_test_device_radix_sort_env.cu @@ -390,3 +390,162 @@ TEST_CASE("Device radix sort keys descending uses custom stream", "[radix_sort][ REQUIRE(keys_out == expected_keys); REQUIRE(cudaSuccess == cudaStreamDestroy(custom_stream)); } + +// using different block sizes yields to different temporary storage sizes, so use a custom policy to influence that +template +struct tiny_onesweep_policy_selector +{ + _CCCL_API constexpr auto operator()(cuda::arch_id arch) const -> cub::detail::radix_sort::radix_sort_policy + { + using default_selector_t = 
cub::detail::radix_sort::policy_selector_from_types; + auto policy = default_selector_t{}(arch); + policy.use_onesweep = true; + policy.onesweep.block_threads = BlockThreads; + policy.onesweep.items_per_thread = 1; + return policy; + } +}; + +template +std::size_t measure_allocated_bytes(CallableT&& run, PolicySelector policy_selector) +{ + cuda::stream_ref stream{cudaStream_t{}}; + size_t bytes_allocated = 0; + size_t bytes_deallocated = 0; + auto env = stdexec::env{ + cuda::std::execution::prop{cuda::mr::__get_memory_resource_t{}, + device_memory_resource{{}, stream.get(), &bytes_allocated, &bytes_deallocated}}, + cuda::std::execution::prop{cuda::get_stream_t{}, cuda::stream_ref{stream}}, + cuda::execution::__tune(policy_selector)}; + REQUIRE(cudaSuccess == run(env)); + stream.sync(); + CHECK(bytes_allocated > 0); + CHECK(bytes_allocated == bytes_deallocated); + return bytes_allocated; +} + +TEST_CASE("DeviceRadixSort::SortPairs can be tuned", "[radix_sort][device]") +{ + auto l = [&](auto env) { + auto data = c2h::device_vector(10'000); // must be larger than the single tile path + return cub::DeviceRadixSort::SortPairs( + data.data().get(), + data.data().get(), + data.data().get(), + data.data().get(), + static_cast(data.size()), + 0, + 32, + env); + }; + + auto default_bytes = measure_allocated_bytes(l, tiny_onesweep_policy_selector{}); + auto tuned_bytes = measure_allocated_bytes(l, tiny_onesweep_policy_selector{}); + + CHECK(default_bytes != tuned_bytes); +} + +TEST_CASE("DeviceRadixSort::SortPairs DoubleBuffer can be tuned", "[radix_sort][device]") +{ + auto l = [&](auto env) { + auto data = c2h::device_vector(10'000); // must be larger than the single tile path + cub::DoubleBuffer double_buf(data.data().get(), data.data().get()); + return cub::DeviceRadixSort::SortPairs(double_buf, double_buf, static_cast(data.size()), 0, 32, env); + }; + + auto default_bytes = measure_allocated_bytes(l, tiny_onesweep_policy_selector{}); + auto tuned_bytes = 
measure_allocated_bytes(l, tiny_onesweep_policy_selector{}); + + CHECK(default_bytes != tuned_bytes); +} + +TEST_CASE("DeviceRadixSort::SortPairsDescending can be tuned", "[radix_sort][device]") +{ + auto l = [&](auto env) { + auto data = c2h::device_vector(10'000); // must be larger than the single tile path + return cub::DeviceRadixSort::SortPairsDescending( + data.data().get(), + data.data().get(), + data.data().get(), + data.data().get(), + static_cast(data.size()), + 0, + 32, + env); + }; + + auto default_bytes = measure_allocated_bytes(l, tiny_onesweep_policy_selector{}); + auto tuned_bytes = measure_allocated_bytes(l, tiny_onesweep_policy_selector{}); + + CHECK(default_bytes != tuned_bytes); +} + +TEST_CASE("DeviceRadixSort::SortPairsDescending DoubleBuffer can be tuned", "[radix_sort][device]") +{ + auto l = [&](auto env) { + auto data = c2h::device_vector(10'000); // must be larger than the single tile path + cub::DoubleBuffer double_buf(data.data().get(), data.data().get()); + return cub::DeviceRadixSort::SortPairsDescending(double_buf, double_buf, static_cast(data.size()), 0, 32, env); + }; + + auto default_bytes = measure_allocated_bytes(l, tiny_onesweep_policy_selector{}); + auto tuned_bytes = measure_allocated_bytes(l, tiny_onesweep_policy_selector{}); + + CHECK(default_bytes != tuned_bytes); +} + +TEST_CASE("DeviceRadixSort::SortKeys can be tuned", "[radix_sort][device]") +{ + auto l = [&](auto env) { + auto data = c2h::device_vector(10'000); // must be larger than the single tile path + return cub::DeviceRadixSort::SortKeys( + data.data().get(), data.data().get(), static_cast(data.size()), 0, 32, env); + }; + + auto default_bytes = measure_allocated_bytes(l, tiny_onesweep_policy_selector{}); + auto tuned_bytes = measure_allocated_bytes(l, tiny_onesweep_policy_selector{}); + + CHECK(default_bytes != tuned_bytes); +} + +TEST_CASE("DeviceRadixSort::SortKeys DoubleBuffer can be tuned", "[radix_sort][device]") +{ + auto l = [&](auto env) { + auto data = 
c2h::device_vector(10'000); // must be larger than the single tile path + cub::DoubleBuffer double_buf(data.data().get(), data.data().get()); + return cub::DeviceRadixSort::SortKeys(double_buf, static_cast(data.size()), 0, 32, env); + }; + + auto default_bytes = measure_allocated_bytes(l, tiny_onesweep_policy_selector{}); + auto tuned_bytes = measure_allocated_bytes(l, tiny_onesweep_policy_selector{}); + + CHECK(default_bytes != tuned_bytes); +} + +TEST_CASE("DeviceRadixSort::SortKeysDescending can be tuned", "[radix_sort][device]") +{ + auto l = [&](auto env) { + auto data = c2h::device_vector(10'000); // must be larger than the single tile path + return cub::DeviceRadixSort::SortKeysDescending( + data.data().get(), data.data().get(), static_cast(data.size()), 0, 32, env); + }; + + auto default_bytes = measure_allocated_bytes(l, tiny_onesweep_policy_selector{}); + auto tuned_bytes = measure_allocated_bytes(l, tiny_onesweep_policy_selector{}); + + CHECK(default_bytes != tuned_bytes); +} + +TEST_CASE("DeviceRadixSort::SortKeysDescending DoubleBuffer can be tuned", "[radix_sort][device]") +{ + auto l = [&](auto env) { + auto data = c2h::device_vector(10'000); // must be larger than the single tile path + cub::DoubleBuffer double_buf(data.data().get(), data.data().get()); + return cub::DeviceRadixSort::SortKeysDescending(double_buf, static_cast(data.size()), 0, 32, env); + }; + + auto default_bytes = measure_allocated_bytes(l, tiny_onesweep_policy_selector{}); + auto tuned_bytes = measure_allocated_bytes(l, tiny_onesweep_policy_selector{}); + + CHECK(default_bytes != tuned_bytes); +} diff --git a/cub/test/catch2_test_device_reduce_env.cu b/cub/test/catch2_test_device_reduce_env.cu index 13625e85807..954f7b8454d 100644 --- a/cub/test/catch2_test_device_reduce_env.cu +++ b/cub/test/catch2_test_device_reduce_env.cu @@ -105,7 +105,7 @@ TEST_CASE("Device sum works with default environment", "[reduce][device]") } template -struct reduce_tuning : cub::detail::reduce::tuning> 
+struct reduce_tuning { _CCCL_API constexpr auto operator()(cuda::arch_id /*arch*/) const -> cub::detail::reduce::reduce_policy { @@ -115,20 +115,16 @@ struct reduce_tuning : cub::detail::reduce::tuning> } }; -struct get_scan_tuning_query_t +struct unrelated_policy {}; -struct scan_tuning +struct unrelated_tuning { - [[nodiscard]] _CCCL_NODEBUG_API constexpr auto query(const get_scan_tuning_query_t&) const noexcept + // should never be called + auto operator()(cuda::arch_id /*arch*/) const -> unrelated_policy { - return *this; + throw 1337; } - - // Make sure this is not used - template - struct fn - {}; }; using block_sizes = c2h::type_list, cuda::std::integral_constant>; @@ -143,8 +139,8 @@ C2H_TEST("Device reduce can be tuned", "[reduce][device]", block_sizes) auto d_in = cuda::constant_iterator(1); auto d_out = thrust::device_vector(1); - // We are expecting that `scan_tuning` is ignored - auto env = cuda::execution::__tune(reduce_tuning{}, scan_tuning{}); + // We are expecting that `unrelated_tuning` is ignored + auto env = cuda::execution::__tune(reduce_tuning{}, unrelated_tuning{}); REQUIRE(cudaSuccess == cub::DeviceReduce::Reduce(d_in, d_out.begin(), num_items, block_size_check, 0, env)); REQUIRE(d_out[0] == num_items); @@ -159,8 +155,8 @@ C2H_TEST("Device sum can be tuned", "[reduce][device]", block_sizes) auto d_in = cuda::constant_iterator(1); auto d_out = thrust::device_vector(1); - // We are expecting that `scan_tuning` is ignored - auto env = cuda::execution::__tune(reduce_tuning{}, scan_tuning{}); + // We are expecting that `unrelated_tuning` is ignored + auto env = cuda::execution::__tune(reduce_tuning{}, unrelated_tuning{}); REQUIRE(cudaSuccess == cub::DeviceReduce::Sum(d_in, d_out.begin(), num_items, env)); REQUIRE(d_out[0] == num_items); diff --git a/cub/test/catch2_test_device_reduce_nondeterministic.cu b/cub/test/catch2_test_device_reduce_nondeterministic.cu index 9d3c00ea423..a85c794fa7b 100644 --- 
a/cub/test/catch2_test_device_reduce_nondeterministic.cu +++ b/cub/test/catch2_test_device_reduce_nondeterministic.cu @@ -29,42 +29,19 @@ using float_type_list = #endif >; -template -struct AgentReducePolicy -{ - /// Number of items per vectorized load - static constexpr int VECTOR_LOAD_LENGTH = 4; - - /// Cooperative block-wide reduction algorithm to use - static constexpr cub::BlockReduceAlgorithm BLOCK_ALGORITHM = - cub::BlockReduceAlgorithm::BLOCK_REDUCE_WARP_REDUCTIONS_NONDETERMINISTIC; - - /// Cache load modifier for reading input elements - static constexpr cub::CacheLoadModifier LOAD_MODIFIER = cub::CacheLoadModifier::LOAD_DEFAULT; - constexpr static int ITEMS_PER_THREAD = NOMINAL_ITEMS_PER_THREAD_4B; - constexpr static int BLOCK_THREADS = NOMINAL_BLOCK_THREADS_4B; -}; - template -struct hub_t +struct custom_policy_selector { - struct Policy : cub::ChainedPolicy<300, Policy, Policy> + _CCCL_API constexpr auto operator()(::cuda::arch_id) const -> cub::detail::reduce::reduce_policy { - constexpr static int ITEMS_PER_THREAD = ItemsPerThread; - - using ReducePolicy = AgentReducePolicy; - - // SingleTilePolicy - using SingleTilePolicy = ReducePolicy; - - // SegmentedReducePolicy - using SegmentedReducePolicy = ReducePolicy; - - // ReduceNondeterministicPolicy - using ReduceNondeterministicPolicy = ReducePolicy; - }; - - using MaxPolicy = Policy; + auto rp = cub::detail::reduce::agent_reduce_policy{ + BlockSize, + ItemsPerThread, + 4, + cub::BlockReduceAlgorithm::BLOCK_REDUCE_WARP_REDUCTIONS_NONDETERMINISTIC, + cub::CacheLoadModifier::LOAD_DEFAULT}; + return {rp, rp, rp}; + } }; C2H_TEST("Nondeterministic Device reduce works with float and double on gpu", @@ -157,11 +134,11 @@ C2H_TEST("Nondeterministic Device reduce works with float and double on gpu with c2h::device_vector d_output_p1(1); c2h::device_vector d_output_p2(1); - auto env1 = cuda::std::execution::env{ - cuda::execution::require(cuda::execution::determinism::not_guaranteed), 
cuda::execution::__tune(hub_t<1, 128>{})}; + auto env1 = cuda::std::execution::env{cuda::execution::require(cuda::execution::determinism::not_guaranteed), + cuda::execution::__tune(custom_policy_selector<1, 128>{})}; - auto env2 = cuda::std::execution::env{ - cuda::execution::require(cuda::execution::determinism::not_guaranteed), cuda::execution::__tune(hub_t<2, 256>{})}; + auto env2 = cuda::std::execution::env{cuda::execution::require(cuda::execution::determinism::not_guaranteed), + cuda::execution::__tune(custom_policy_selector<2, 256>{})}; REQUIRE( cudaSuccess == cub::DeviceReduce::Reduce(d_input.begin(), d_output_p1.begin(), num_items, min_op, init, env1)); diff --git a/cub/test/catch2_test_device_scan_env.cu b/cub/test/catch2_test_device_scan_env.cu index e8a692b6051..d490142bd79 100644 --- a/cub/test/catch2_test_device_scan_env.cu +++ b/cub/test/catch2_test_device_scan_env.cu @@ -113,6 +113,7 @@ TEST_CASE("Device scan exclusive sum works with default environment", "[sum][dev REQUIRE(d_out[0] == value_t{0}); } +// TODO(bgruber): convert to the new tuning API template struct scan_tuning : cub::detail::scan::tuning> { @@ -143,24 +144,16 @@ struct scan_tuning : cub::detail::scan::tuning> }; }; -struct get_reduce_tuning_query_t +struct unrelated_policy {}; -struct reduce_tuning +struct unrelated_tuning { - [[nodiscard]] _CCCL_NODEBUG_API constexpr auto query(const get_reduce_tuning_query_t&) const noexcept + // should never be called + auto operator()(cuda::arch_id /*arch*/) const -> unrelated_policy { - return *this; + throw 1337; } - - // Make sure this is not used - template - struct fn - {}; }; using block_sizes = c2h::type_list, cuda::std::integral_constant>; @@ -175,8 +168,8 @@ C2H_TEST("Device scan exclusive-scan can be tuned", "[scan][device]", block_size auto d_in = cuda::constant_iterator(1); auto d_out = thrust::device_vector(num_items); - // We are expecting that `reduce_tuning` is ignored - auto env = cuda::execution::__tune(scan_tuning{}, 
reduce_tuning{}); + // We are expecting that `unrelated_tuning` is ignored + auto env = cuda::execution::__tune(scan_tuning{}, unrelated_tuning{}); REQUIRE(cudaSuccess == cub::DeviceScan::ExclusiveScan(d_in, d_out.begin(), block_size_check, 0, num_items, env)); @@ -195,8 +188,8 @@ C2H_TEST("Device scan exclusive-sum can be tuned", "[scan][device]", block_sizes auto d_in = cuda::constant_iterator(1); auto d_out = thrust::device_vector(num_items); - // We are expecting that `reduce_tuning` is ignored - auto env = cuda::execution::__tune(scan_tuning{}, reduce_tuning{}); + // We are expecting that `unrelated_tuning` is ignored + auto env = cuda::execution::__tune(scan_tuning{}, unrelated_tuning{}); REQUIRE(cudaSuccess == cub::DeviceScan::ExclusiveSum(d_in, d_out.begin(), num_items, env)); @@ -248,8 +241,8 @@ C2H_TEST("Device scan inclusive-scan can be tuned", "[scan][device]", block_size auto d_in = cuda::constant_iterator(1); auto d_out = thrust::device_vector(num_items); - // We are expecting that `reduce_tuning` is ignored - auto env = cuda::execution::__tune(scan_tuning{}, reduce_tuning{}); + // We are expecting that `unrelated_tuning` is ignored + auto env = cuda::execution::__tune(scan_tuning{}, unrelated_tuning{}); REQUIRE(cudaSuccess == cub::DeviceScan::InclusiveScan(d_in, d_out.begin(), block_size_check, num_items, env)); @@ -291,8 +284,8 @@ C2H_TEST("Device scan inclusive-scan-init can be tuned", "[scan][device]", block int init{10}; - // We are expecting that `reduce_tuning` is ignored - auto env = cuda::execution::__tune(scan_tuning{}, reduce_tuning{}); + // We are expecting that `unrelated_tuning` is ignored + auto env = cuda::execution::__tune(scan_tuning{}, unrelated_tuning{}); REQUIRE( cudaSuccess == cub::DeviceScan::InclusiveScanInit(d_in, d_out.begin(), block_size_check, init, num_items, env)); diff --git a/cub/test/catch2_test_device_segmented_reduce_env.cu b/cub/test/catch2_test_device_segmented_reduce_env.cu index 2f9f7ee30f3..72df98e3ba9 100644 
--- a/cub/test/catch2_test_device_segmented_reduce_env.cu +++ b/cub/test/catch2_test_device_segmented_reduce_env.cu @@ -10,47 +10,26 @@ #include template -struct reduce_tuning : cub::detail::reduce::tuning> +struct reduce_tuning { - template - struct fn + _CCCL_API constexpr auto operator()(::cuda::arch_id) const -> cub::detail::reduce::reduce_policy { - struct Policy500 : cub::ChainedPolicy<500, Policy500, Policy500> - { - struct ReducePolicy - { - static constexpr int VECTOR_LOAD_LENGTH = 1; - - static constexpr cub::BlockReduceAlgorithm BLOCK_ALGORITHM = cub::BLOCK_REDUCE_WARP_REDUCTIONS; - - static constexpr cub::CacheLoadModifier LOAD_MODIFIER = cub::LOAD_DEFAULT; - - static constexpr int ITEMS_PER_THREAD = 1; - static constexpr int BLOCK_THREADS = BlockThreads; - }; - - using SingleTilePolicy = ReducePolicy; - using SegmentedReducePolicy = ReducePolicy; - }; - - using MaxPolicy = Policy500; - }; + auto rp = cub::detail::reduce::agent_reduce_policy{ + BlockThreads, 1, 1, cub::BLOCK_REDUCE_WARP_REDUCTIONS, cub::LOAD_DEFAULT}; + return {rp, rp, rp}; + } }; -struct get_scan_tuning_query_t +struct unrelated_policy {}; -struct scan_tuning +struct unrelated_tuning { - [[nodiscard]] _CCCL_NODEBUG_API constexpr auto query(const get_scan_tuning_query_t&) const noexcept + // should never be called + auto operator()(cuda::arch_id /*arch*/) const -> unrelated_policy { - return *this; + throw 1337; } - - // Make sure this is not used - template - struct fn - {}; }; using block_sizes = c2h::type_list, cuda::std::integral_constant>; @@ -65,8 +44,8 @@ C2H_TEST("Device segmented sum can be tuned", "[reduce][device]", block_sizes) thrust::device_vector d_in{8, 6, 7, 5, 3, 0, 9}; thrust::device_vector d_out(3); - // We are expecting that `scan_tuning` is ignored - auto env = cuda::execution::__tune(reduce_tuning{}, scan_tuning{}); + // We are expecting that `unrelated_tuning` is ignored + auto env = cuda::execution::__tune(reduce_tuning{}, unrelated_tuning{}); auto error = 
cub::DeviceSegmentedReduce::Sum(d_in.begin(), d_out.begin(), num_segments, d_offsets_it, d_offsets_it + 1, env); diff --git a/cub/test/catch2_test_device_transform_env.cu b/cub/test/catch2_test_device_transform_env.cu index f9e3b8c3fb8..642392ea113 100644 --- a/cub/test/catch2_test_device_transform_env.cu +++ b/cub/test/catch2_test_device_transform_env.cu @@ -224,8 +224,7 @@ C2H_TEST("DeviceTransform::TransformStableArgumentAddresses custom stream", "[de } // use a policy selector that prescribes to run with exactly 8 threads per block and 3 items per thread -// TODO(bgruber): can we get by without the base class? -struct my_policy_selector : cub::detail::transform::tuning +struct my_policy_selector { _CCCL_API constexpr auto operator()(cuda::arch_id) const -> cub::detail::transform::transform_policy { diff --git a/libcudacxx/include/cuda/__execution/tune.h b/libcudacxx/include/cuda/__execution/tune.h index b273930dd87..fb009470c86 100644 --- a/libcudacxx/include/cuda/__execution/tune.h +++ b/libcudacxx/include/cuda/__execution/tune.h @@ -21,6 +21,7 @@ # pragma system_header #endif // no system header +#include #include #include #include @@ -51,16 +52,18 @@ struct __get_tuning_t _CCCL_GLOBAL_CONSTANT auto __get_tuning = __get_tuning_t{}; -template -[[nodiscard]] _CCCL_NODEBUG_API auto __tune(_Tunings...) +template +[[nodiscard]] _CCCL_NODEBUG_API auto __tune(_PolicySelectors...) 
{ - static_assert((::cuda::std::is_empty_v<_Tunings> && ...), "Stateful tunings are not implemented"); + static_assert((::cuda::std::is_empty_v<_PolicySelectors> && ...), "Stateful policy selectors are not implemented"); - // clang < 19 doesn't like this code // since all the tunings are stateless, let's ignore incoming parameters - ::cuda::std::execution::env<_Tunings...> __env{}; - return ::cuda::std::execution::prop{__get_tuning_t{}, __env}; + // we use the return type of the policy_selector as tag + using tuning_env = ::cuda::std::execution::env< + ::cuda::std::execution::prop...>; + + return ::cuda::std::execution::prop{__get_tuning_t{}, tuning_env{}}; } _CCCL_END_NAMESPACE_CUDA_EXECUTION diff --git a/libcudacxx/test/libcudacxx/cuda/execution/tune.pass.cpp b/libcudacxx/test/libcudacxx/cuda/execution/tune.pass.cpp index 5944599ed2c..088670a862d 100644 --- a/libcudacxx/test/libcudacxx/cuda/execution/tune.pass.cpp +++ b/libcudacxx/test/libcudacxx/cuda/execution/tune.pass.cpp @@ -10,54 +10,31 @@ #include -struct get_reduce_tuning_query_t -{}; +struct reduce_policy +{ + int block_threads; +}; -template -struct reduce_tuning +template +struct reduce_policy_selector { - [[nodiscard]] _CCCL_NODEBUG_API constexpr auto query(const get_reduce_tuning_query_t&) const noexcept -> Derived + _CCCL_API constexpr auto operator()(cuda::arch_id /*arch*/) const -> reduce_policy { - return static_cast(*this); + return {BlockThreads / sizeof(T)}; } }; -template -struct reduce : reduce_tuning> +struct scan_policy { - template - struct type - { - struct max_policy - { - struct reduce_policy - { - static constexpr int block_threads = BlockThreads / sizeof(T); - }; - }; - }; + int block_threads = 1; }; -struct get_scan_tuning_query_t -{}; - -struct scan_tuning +struct scan_policy_selector { - [[nodiscard]] _CCCL_NODEBUG_API constexpr auto query(const get_scan_tuning_query_t&) const noexcept + _CCCL_API constexpr auto operator()(cuda::arch_id /*arch*/) const -> scan_policy { - return 
*this; + return {}; } - - struct type - { - struct max_policy - { - struct reduce_policy - { - static constexpr int block_threads = 1; - }; - }; - }; }; __host__ __device__ void test() @@ -65,15 +42,13 @@ __host__ __device__ void test() constexpr int nominal_block_threads = 256; constexpr int block_threads = nominal_block_threads / sizeof(int); - using env_t = decltype(cuda::execution::__tune(reduce{}, scan_tuning{})); + using env_t = decltype(cuda::execution::__tune(reduce{}, scan_tuning{})); using tuning_t = cuda::std::execution::__query_result_t; - using reduce_tuning_t = cuda::std::execution::__query_result_t; - using scan_tuning_t = cuda::std::execution::__query_result_t; - using reduce_policy_t = reduce_tuning_t::type; - using scan_policy_t = scan_tuning_t::type; + using reduce_policy_t = cuda::std::execution::__query_result_t; + using scan_policy_t = cuda::std::execution::__query_result_t; - static_assert(reduce_policy_t::max_policy::reduce_policy::block_threads == block_threads); - static_assert(scan_policy_t::max_policy::reduce_policy::block_threads == 1); + static_assert(reduce_policy_t{}(cuda::arch_id::sm_75). : block_threads == block_threads); + static_assert(scan_policy_t{}(cuda::arch_id::sm_75). : block_threads == 1); } int main(int, char**)