1010
1111#include < cub/detail/choose_offset.cuh> // cub::detail::choose_offset_t
1212#include < cub/detail/launcher/cuda_driver.cuh> // cub::detail::CudaDriverLauncherFactory
13- #include < cub/detail/ptx-json-parser.cuh>
1413#include < cub/device/dispatch/dispatch_segmented_sort.cuh> // cub::DispatchSegmentedSort
1514#include < cub/device/dispatch/kernels/kernel_segmented_sort.cuh> // DeviceSegmentedSort kernels
16- #include < cub/device/dispatch/tuning/tuning_segmented_sort.cuh> // policy_hub
15+ #include < cub/device/dispatch/tuning/tuning_segmented_sort.cuh>
1716#include < cub/thread/thread_load.cuh> // cub::LoadModifier
1817
1918#include < exception> // std::exception
3231#include " util/types.h"
3332#include < cccl/c/segmented_sort.h>
3433#include < cccl/c/types.h> // cccl_type_info
35- #include < nlohmann/json.hpp>
3634#include < nvrtc/command_list.h>
3735#include < nvrtc/ltoir_list_appender.h>
3836#include < util/build_utils.h>
3937
4038struct device_segmented_sort_policy_selector ;
41- struct device_three_way_partition_policy ;
39+ struct device_three_way_partition_policy_selector ;
4240using OffsetT = ptrdiff_t ;
4341static_assert (std::is_same_v<cub::detail::choose_signed_offset_t <OffsetT>, OffsetT>, " OffsetT must be long" );
4442
@@ -296,8 +294,8 @@ std::string get_three_way_partition_init_kernel_name()
296294
297295std::string get_three_way_partition_kernel_name (std::string_view large_selector_t , std::string_view small_selector_t )
298296{
299- std::string chained_policy_t ;
300- check (cccl_type_name_from_nvrtc<device_three_way_partition_policy >(&chained_policy_t ));
297+ std::string policy_selector_t ;
298+ check (cccl_type_name_from_nvrtc<device_three_way_partition_policy_selector >(&policy_selector_t ));
301299
302300 static constexpr std::string_view input_it_t =
303301 " thrust::counting_iterator<cub::detail::segmented_sort::local_segment_index_t>" ;
@@ -317,7 +315,7 @@ std::string get_three_way_partition_kernel_name(std::string_view large_selector_
317315 return std::format (
318316 " cub::detail::three_way_partition::DeviceThreeWayPartitionKernel<{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}, "
319317 " {10}>" ,
320- chained_policy_t , // 0 (ChainedPolicyT )
318+ policy_selector_t , // 0 (PolicySelector )
321319 input_it_t , // 1 (InputIteratorT)
322320 first_out_it_t , // 2 (FirstOutputIteratorT)
323321 second_out_it_t , // 3 (SecondOutputIteratorT)
@@ -348,55 +346,6 @@ struct partition_kernel_source
348346 return build.offset_type .size ;
349347 }
350348};
351-
352- struct partition_runtime_tuning_policy
353- {
354- cub::detail::RuntimeThreeWayPartitionAgentPolicy three_way_partition;
355-
356- auto ThreeWayPartition () const
357- {
358- return three_way_partition;
359- }
360-
361- using MaxPolicy = partition_runtime_tuning_policy;
362-
363- template <typename F>
364- cudaError_t Invoke (int , F& op)
365- {
366- return op.template Invoke <partition_runtime_tuning_policy>(*this );
367- }
368- };
369-
370- std::string get_three_way_partition_policy_delay_constructor (const nlohmann::json& partition_policy)
371- {
372- auto delay_ctor_info = partition_policy[" DelayConstructor" ];
373-
374- std::string delay_ctor_params;
375- for (auto && param : delay_ctor_info[" params" ])
376- {
377- delay_ctor_params.append (to_string (param) + " , " );
378- }
379- delay_ctor_params.erase (delay_ctor_params.size () - 2 ); // remove last ", "
380-
381- return std::format (" cub::detail::{}<{}>" , delay_ctor_info[" name" ].get <std::string>(), delay_ctor_params);
382- }
383-
384- std::string inject_delay_constructor_into_three_way_policy (
385- const std::string& three_way_partition_policy_str, const std::string& delay_constructor_type)
386- {
387- // Insert before the final closing of the struct (right before the sequence "};")
388- static constexpr std::string_view needle = " };" ;
389- const auto pos = three_way_partition_policy_str.rfind (needle);
390- if (pos == std::string::npos)
391- {
392- return three_way_partition_policy_str; // unexpected; return as-is
393- }
394- const std::string insertion =
395- std::format (" \n struct detail {{ using delay_constructor_t = {}; }}; \n " , delay_constructor_type);
396- std::string out = three_way_partition_policy_str;
397- out.insert (pos, insertion);
398- return out;
399- }
400349} // namespace segmented_sort
401350
402351struct segmented_sort_keys_input_iterator_tag ;
467416 const auto [end_offset_iterator_name, end_offset_iterator_src] =
468417 get_specialization<segmented_sort_end_offset_iterator_tag>(template_id<input_iterator_traits>(), end_offset_it);
469418
470- const auto offset_t = cccl_type_enum_to_name (cccl_type_enum::CCCL_INT64);
471-
472419 const std::string key_t = cccl_type_enum_to_name (keys_in_it.value_type .type );
473420 const std::string value_t =
474421 keys_only ? " cub::NullType" : cccl_type_enum_to_name<items_storage_t >(values_in_it.value_type .type );
543490 key_t , // 0
544491 value_t ); // 1
545492
546- static constexpr std::string_view three_way_partition_policy_hub_expr =
547- " cub::detail::three_way_partition::policy_hub<cub::detail::segmented_sort::local_segment_index_t, "
548- " cub::detail::three_way_partition::per_partition_offset_t>" ;
493+ const auto partition_policy_sel = cub::detail::three_way_partition::policy_selector{
494+ cub::detail::classify_type<cub::detail::segmented_sort::local_segment_index_t >,
495+ int {sizeof (cub::detail::segmented_sort::local_segment_index_t )},
496+ int {sizeof (cub::detail::three_way_partition::per_partition_offset_t )}};
497+
498+ // TODO(bgruber): drop this if tuning policies become formattable
499+ std::stringstream partition_policy_sel_str;
500+ partition_policy_sel_str << partition_policy_sel (cuda::to_arch_id (cuda::compute_capability{cc_major, cc_minor}));
501+
502+ const auto three_way_partition_policy_expr = std::format (
503+ " cub::detail::three_way_partition::policy_selector_from_types<cub::detail::segmented_sort::local_segment_index_t, "
504+ " cub::detail::three_way_partition::per_partition_offset_t>" );
549505
550506 const std::string final_src = std::format (
551507 R"XXX(
@@ -573,18 +529,17 @@ struct __align__({4}) items_storage_t {{
573529{9}
574530{10}
575531using device_segmented_sort_policy_selector = {11};
532+ using device_three_way_partition_policy_selector = {13};
576533using namespace cub;
534+ using namespace cub::detail;
577535using namespace cub::detail::segmented_sort;
536+ using namespace cub::detail::three_way_partition;
578537static_assert(
579538 device_segmented_sort_policy_selector()(::cuda::arch_id{{CUB_PTX_ARCH / 10}}) == {12},
580539 "Host generated and JIT compiled policy mismatch");
581- using device_three_way_partition_policy = {13}::MaxPolicy;
582-
583- #include <cub/detail/ptx-json/json.cuh>
584- __device__ consteval auto& three_way_partition_policy_generator() {{
585- return ptx_json::id<ptx_json::string("device_three_way_partition_policy")>()
586- = cub::detail::three_way_partition::ThreeWayPartitionPolicyWrapper<device_three_way_partition_policy::ActivePolicy>::EncodedPolicy();
587- }}
540+ static_assert(
541+ device_three_way_partition_policy_selector()(::cuda::arch_id{{CUB_PTX_ARCH / 10}}) == {14},
542+ "Host generated and JIT compiled three-way partition policy mismatch");
588543)XXX" ,
589544 jit_template_header_contents, // 0
590545 keys_in_it.value_type .size , // 1
@@ -599,7 +554,8 @@ __device__ consteval auto& three_way_partition_policy_generator() {{
599554 small_selector_src, // 10
600555 segmented_sort_policy_expr, // 11
601556 segmented_sort_policy_sel_str.view (), // 12
602- three_way_partition_policy_hub_expr); // 13
557+ three_way_partition_policy_expr, // 13
558+ partition_policy_sel_str.view ()); // 14
603559
604560#if false // CCCL_DEBUGGING_SWITCH
605561 fflush (stderr);
@@ -617,7 +573,6 @@ __device__ consteval auto& three_way_partition_policy_generator() {{
617573 " -dlto" ,
618574 " -default-device" ,
619575 " -DCUB_DISABLE_CDP" ,
620- " -DCUB_ENABLE_POLICY_PTX_JSON" , // TODO(bgruber): remove after we ported three way partition to the new tuning API
621576 " -std=c++20" };
622577
623578 cccl::detail::extend_args_with_build_config (args, config);
@@ -691,25 +646,16 @@ __device__ consteval auto& three_way_partition_policy_generator() {{
691646 check (cuLibraryGetKernel (
692647 &build_ptr->three_way_partition_kernel , build_ptr->library , three_way_partition_kernel_lowered_name.c_str ()));
693648
694- // TODO(bgruber): convert to the new tuning API
695- nlohmann::json partition_policy =
696- cub::detail::ptx_json::parse (" device_three_way_partition_policy" , {result.data .get (), result.size });
697-
698- using cub::detail::RuntimeThreeWayPartitionAgentPolicy;
699- auto three_way_partition_policy =
700- RuntimeThreeWayPartitionAgentPolicy::from_json (partition_policy, " ThreeWayPartitionPolicy" );
701-
702649 build_ptr->cc = cc_major * 10 + cc_minor;
703650 build_ptr->large_segments_selector_op = large_selector_op;
704651 build_ptr->small_segments_selector_op = small_selector_op;
705652 build_ptr->cubin = (void *) result.data .release ();
706653 build_ptr->cubin_size = result.size ;
707654 build_ptr->key_type = keys_in_it.value_type ;
708655 build_ptr->offset_type = cccl_type_info{sizeof (OffsetT), alignof (OffsetT), cccl_type_enum::CCCL_INT64};
709- // Use the runtime policy extracted via from_json
710- build_ptr->runtime_policy = new cub::detail::segmented_sort::policy_selector{policy_sel};
711- build_ptr->partition_runtime_policy = new segmented_sort::partition_runtime_tuning_policy{three_way_partition_policy};
712- build_ptr->order = sort_order;
656+ build_ptr->runtime_policy = new cub::detail::segmented_sort::policy_selector{policy_sel};
657+ build_ptr->partition_runtime_policy = new cub::detail::three_way_partition::policy_selector{partition_policy_sel};
658+ build_ptr->order = sort_order;
713659
714660 return CUDA_SUCCESS;
715661}
@@ -798,34 +744,24 @@ CUresult cccl_device_segmented_sort_impl(
798744 cub::DoubleBuffer<indirect_arg_t > d_values_double_buffer (
799745 *static_cast <indirect_arg_t **>(&val_arg_in), *static_cast <indirect_arg_t **>(&val_arg_out));
800746
801- // TODO(bgruber): remove all template arguments except the first two (the others can be deduced)
802- auto exec_status = cub::detail::segmented_sort::dispatch<
803- Order,
804- OffsetT, // OffsetT
805- indirect_arg_t , // KeyT
806- indirect_arg_t , // ValueT
807- indirect_iterator_t , // BeginOffsetIteratorT
808- indirect_iterator_t , // EndOffsetIteratorT
809- cub::detail::segmented_sort::policy_selector, // PolicySelector
810- segmented_sort::segmented_sort_kernel_source, // KernelSource
811- segmented_sort::partition_runtime_tuning_policy // PartitionPolicyHub
812- >(d_temp_storage,
813- *temp_storage_bytes,
814- d_keys_double_buffer,
815- d_values_double_buffer,
816- num_items,
817- num_segments,
818- indirect_iterator_t {start_offset_in},
819- indirect_iterator_t {end_offset_in},
820- is_overwrite_okay,
821- stream,
822- /* policy_selector */
823- *static_cast <cub::detail::segmented_sort::policy_selector*>(build.runtime_policy ),
824- /* partition_max_policy */
825- *static_cast <segmented_sort::partition_runtime_tuning_policy*>(build.partition_runtime_policy ),
826- /* kernel_source */ segmented_sort::segmented_sort_kernel_source{build},
827- /* partition_kernel_source */ segmented_sort::partition_kernel_source{build},
828- /* launcher_factory */ cub::detail::CudaDriverLauncherFactory{cu_device, build.cc });
747+ auto exec_status = cub::detail::segmented_sort::dispatch<Order, OffsetT>(
748+ d_temp_storage,
749+ *temp_storage_bytes,
750+ d_keys_double_buffer,
751+ d_values_double_buffer,
752+ num_items,
753+ num_segments,
754+ indirect_iterator_t {start_offset_in},
755+ indirect_iterator_t {end_offset_in},
756+ is_overwrite_okay,
757+ stream,
758+ /* policy_selector */
759+ *static_cast <cub::detail::segmented_sort::policy_selector*>(build.runtime_policy ),
760+ /* partition_policy_selector */
761+ *static_cast <cub::detail::three_way_partition::policy_selector*>(build.partition_runtime_policy ),
762+ /* kernel_source */ segmented_sort::segmented_sort_kernel_source{build},
763+ /* partition_kernel_source */ segmented_sort::partition_kernel_source{build},
764+ /* launcher_factory */ cub::detail::CudaDriverLauncherFactory{cu_device, build.cc });
829765
830766 *selector = d_keys_double_buffer.selector ;
831767 error = static_cast <CUresult>(exec_status);
909845 // Clean up the runtime policies
910846 std::unique_ptr<cub::detail::segmented_sort::policy_selector> rtp (
911847 static_cast <cub::detail::segmented_sort::policy_selector*>(build_ptr->runtime_policy ));
912- std::unique_ptr<segmented_sort::partition_runtime_tuning_policy > prtp (
913- static_cast <segmented_sort::partition_runtime_tuning_policy *>(build_ptr->partition_runtime_policy ));
848+ std::unique_ptr<cub::detail::three_way_partition::policy_selector > prtp (
849+ static_cast <cub::detail::three_way_partition::policy_selector *>(build_ptr->partition_runtime_policy ));
914850 check (cuLibraryUnload (build_ptr->library ));
915851
916852 return CUDA_SUCCESS;
0 commit comments