Skip to content

Commit 3ed7d1b

Browse files
Implement the new tuning API for DispatchThreeWayPartitionIf (#7900)
Fixes: #7646
1 parent 72c1d9d commit 3ed7d1b

File tree

10 files changed

+750
-390
lines changed

10 files changed

+750
-390
lines changed

c/parallel/src/segmented_sort.cu

Lines changed: 48 additions & 112 deletions
Original file line number | Diff line number | Diff line change
@@ -10,10 +10,9 @@
1010

1111
#include <cub/detail/choose_offset.cuh> // cub::detail::choose_offset_t
1212
#include <cub/detail/launcher/cuda_driver.cuh> // cub::detail::CudaDriverLauncherFactory
13-
#include <cub/detail/ptx-json-parser.cuh>
1413
#include <cub/device/dispatch/dispatch_segmented_sort.cuh> // cub::DispatchSegmentedSort
1514
#include <cub/device/dispatch/kernels/kernel_segmented_sort.cuh> // DeviceSegmentedSort kernels
16-
#include <cub/device/dispatch/tuning/tuning_segmented_sort.cuh> // policy_hub
15+
#include <cub/device/dispatch/tuning/tuning_segmented_sort.cuh>
1716
#include <cub/thread/thread_load.cuh> // cub::LoadModifier
1817

1918
#include <exception> // std::exception
@@ -32,13 +31,12 @@
3231
#include "util/types.h"
3332
#include <cccl/c/segmented_sort.h>
3433
#include <cccl/c/types.h> // cccl_type_info
35-
#include <nlohmann/json.hpp>
3634
#include <nvrtc/command_list.h>
3735
#include <nvrtc/ltoir_list_appender.h>
3836
#include <util/build_utils.h>
3937

4038
struct device_segmented_sort_policy_selector;
41-
struct device_three_way_partition_policy;
39+
struct device_three_way_partition_policy_selector;
4240
using OffsetT = ptrdiff_t;
4341
static_assert(std::is_same_v<cub::detail::choose_signed_offset_t<OffsetT>, OffsetT>, "OffsetT must be long");
4442

@@ -296,8 +294,8 @@ std::string get_three_way_partition_init_kernel_name()
296294

297295
std::string get_three_way_partition_kernel_name(std::string_view large_selector_t, std::string_view small_selector_t)
298296
{
299-
std::string chained_policy_t;
300-
check(cccl_type_name_from_nvrtc<device_three_way_partition_policy>(&chained_policy_t));
297+
std::string policy_selector_t;
298+
check(cccl_type_name_from_nvrtc<device_three_way_partition_policy_selector>(&policy_selector_t));
301299

302300
static constexpr std::string_view input_it_t =
303301
"thrust::counting_iterator<cub::detail::segmented_sort::local_segment_index_t>";
@@ -317,7 +315,7 @@ std::string get_three_way_partition_kernel_name(std::string_view large_selector_
317315
return std::format(
318316
"cub::detail::three_way_partition::DeviceThreeWayPartitionKernel<{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}, "
319317
"{10}>",
320-
chained_policy_t, // 0 (ChainedPolicyT)
318+
policy_selector_t, // 0 (PolicySelector)
321319
input_it_t, // 1 (InputIteratorT)
322320
first_out_it_t, // 2 (FirstOutputIteratorT)
323321
second_out_it_t, // 3 (SecondOutputIteratorT)
@@ -348,55 +346,6 @@ struct partition_kernel_source
348346
return build.offset_type.size;
349347
}
350348
};
351-
352-
struct partition_runtime_tuning_policy
353-
{
354-
cub::detail::RuntimeThreeWayPartitionAgentPolicy three_way_partition;
355-
356-
auto ThreeWayPartition() const
357-
{
358-
return three_way_partition;
359-
}
360-
361-
using MaxPolicy = partition_runtime_tuning_policy;
362-
363-
template <typename F>
364-
cudaError_t Invoke(int, F& op)
365-
{
366-
return op.template Invoke<partition_runtime_tuning_policy>(*this);
367-
}
368-
};
369-
370-
std::string get_three_way_partition_policy_delay_constructor(const nlohmann::json& partition_policy)
371-
{
372-
auto delay_ctor_info = partition_policy["DelayConstructor"];
373-
374-
std::string delay_ctor_params;
375-
for (auto&& param : delay_ctor_info["params"])
376-
{
377-
delay_ctor_params.append(to_string(param) + ", ");
378-
}
379-
delay_ctor_params.erase(delay_ctor_params.size() - 2); // remove last ", "
380-
381-
return std::format("cub::detail::{}<{}>", delay_ctor_info["name"].get<std::string>(), delay_ctor_params);
382-
}
383-
384-
std::string inject_delay_constructor_into_three_way_policy(
385-
const std::string& three_way_partition_policy_str, const std::string& delay_constructor_type)
386-
{
387-
// Insert before the final closing of the struct (right before the sequence "};")
388-
static constexpr std::string_view needle = "};";
389-
const auto pos = three_way_partition_policy_str.rfind(needle);
390-
if (pos == std::string::npos)
391-
{
392-
return three_way_partition_policy_str; // unexpected; return as-is
393-
}
394-
const std::string insertion =
395-
std::format("\n struct detail {{ using delay_constructor_t = {}; }}; \n", delay_constructor_type);
396-
std::string out = three_way_partition_policy_str;
397-
out.insert(pos, insertion);
398-
return out;
399-
}
400349
} // namespace segmented_sort
401350

402351
struct segmented_sort_keys_input_iterator_tag;
@@ -467,8 +416,6 @@ try
467416
const auto [end_offset_iterator_name, end_offset_iterator_src] =
468417
get_specialization<segmented_sort_end_offset_iterator_tag>(template_id<input_iterator_traits>(), end_offset_it);
469418

470-
const auto offset_t = cccl_type_enum_to_name(cccl_type_enum::CCCL_INT64);
471-
472419
const std::string key_t = cccl_type_enum_to_name(keys_in_it.value_type.type);
473420
const std::string value_t =
474421
keys_only ? "cub::NullType" : cccl_type_enum_to_name<items_storage_t>(values_in_it.value_type.type);
@@ -543,9 +490,18 @@ try
543490
key_t, // 0
544491
value_t); // 1
545492

546-
static constexpr std::string_view three_way_partition_policy_hub_expr =
547-
"cub::detail::three_way_partition::policy_hub<cub::detail::segmented_sort::local_segment_index_t, "
548-
"cub::detail::three_way_partition::per_partition_offset_t>";
493+
const auto partition_policy_sel = cub::detail::three_way_partition::policy_selector{
494+
cub::detail::classify_type<cub::detail::segmented_sort::local_segment_index_t>,
495+
int{sizeof(cub::detail::segmented_sort::local_segment_index_t)},
496+
int{sizeof(cub::detail::three_way_partition::per_partition_offset_t)}};
497+
498+
// TODO(bgruber): drop this if tuning policies become formattable
499+
std::stringstream partition_policy_sel_str;
500+
partition_policy_sel_str << partition_policy_sel(cuda::to_arch_id(cuda::compute_capability{cc_major, cc_minor}));
501+
502+
const auto three_way_partition_policy_expr = std::format(
503+
"cub::detail::three_way_partition::policy_selector_from_types<cub::detail::segmented_sort::local_segment_index_t, "
504+
"cub::detail::three_way_partition::per_partition_offset_t>");
549505

550506
const std::string final_src = std::format(
551507
R"XXX(
@@ -573,18 +529,17 @@ struct __align__({4}) items_storage_t {{
573529
{9}
574530
{10}
575531
using device_segmented_sort_policy_selector = {11};
532+
using device_three_way_partition_policy_selector = {13};
576533
using namespace cub;
534+
using namespace cub::detail;
577535
using namespace cub::detail::segmented_sort;
536+
using namespace cub::detail::three_way_partition;
578537
static_assert(
579538
device_segmented_sort_policy_selector()(::cuda::arch_id{{CUB_PTX_ARCH / 10}}) == {12},
580539
"Host generated and JIT compiled policy mismatch");
581-
using device_three_way_partition_policy = {13}::MaxPolicy;
582-
583-
#include <cub/detail/ptx-json/json.cuh>
584-
__device__ consteval auto& three_way_partition_policy_generator() {{
585-
return ptx_json::id<ptx_json::string("device_three_way_partition_policy")>()
586-
= cub::detail::three_way_partition::ThreeWayPartitionPolicyWrapper<device_three_way_partition_policy::ActivePolicy>::EncodedPolicy();
587-
}}
540+
static_assert(
541+
device_three_way_partition_policy_selector()(::cuda::arch_id{{CUB_PTX_ARCH / 10}}) == {14},
542+
"Host generated and JIT compiled three-way partition policy mismatch");
588543
)XXX",
589544
jit_template_header_contents, // 0
590545
keys_in_it.value_type.size, // 1
@@ -599,7 +554,8 @@ __device__ consteval auto& three_way_partition_policy_generator() {{
599554
small_selector_src, // 10
600555
segmented_sort_policy_expr, // 11
601556
segmented_sort_policy_sel_str.view(), // 12
602-
three_way_partition_policy_hub_expr); // 13
557+
three_way_partition_policy_expr, // 13
558+
partition_policy_sel_str.view()); // 14
603559

604560
#if false // CCCL_DEBUGGING_SWITCH
605561
fflush(stderr);
@@ -617,7 +573,6 @@ __device__ consteval auto& three_way_partition_policy_generator() {{
617573
"-dlto",
618574
"-default-device",
619575
"-DCUB_DISABLE_CDP",
620-
"-DCUB_ENABLE_POLICY_PTX_JSON", // TODO(bgruber): remove after we ported three way partition to the new tuning API
621576
"-std=c++20"};
622577

623578
cccl::detail::extend_args_with_build_config(args, config);
@@ -691,25 +646,16 @@ __device__ consteval auto& three_way_partition_policy_generator() {{
691646
check(cuLibraryGetKernel(
692647
&build_ptr->three_way_partition_kernel, build_ptr->library, three_way_partition_kernel_lowered_name.c_str()));
693648

694-
// TODO(bgruber): convert to the new tuning API
695-
nlohmann::json partition_policy =
696-
cub::detail::ptx_json::parse("device_three_way_partition_policy", {result.data.get(), result.size});
697-
698-
using cub::detail::RuntimeThreeWayPartitionAgentPolicy;
699-
auto three_way_partition_policy =
700-
RuntimeThreeWayPartitionAgentPolicy::from_json(partition_policy, "ThreeWayPartitionPolicy");
701-
702649
build_ptr->cc = cc_major * 10 + cc_minor;
703650
build_ptr->large_segments_selector_op = large_selector_op;
704651
build_ptr->small_segments_selector_op = small_selector_op;
705652
build_ptr->cubin = (void*) result.data.release();
706653
build_ptr->cubin_size = result.size;
707654
build_ptr->key_type = keys_in_it.value_type;
708655
build_ptr->offset_type = cccl_type_info{sizeof(OffsetT), alignof(OffsetT), cccl_type_enum::CCCL_INT64};
709-
// Use the runtime policy extracted via from_json
710-
build_ptr->runtime_policy = new cub::detail::segmented_sort::policy_selector{policy_sel};
711-
build_ptr->partition_runtime_policy = new segmented_sort::partition_runtime_tuning_policy{three_way_partition_policy};
712-
build_ptr->order = sort_order;
656+
build_ptr->runtime_policy = new cub::detail::segmented_sort::policy_selector{policy_sel};
657+
build_ptr->partition_runtime_policy = new cub::detail::three_way_partition::policy_selector{partition_policy_sel};
658+
build_ptr->order = sort_order;
713659

714660
return CUDA_SUCCESS;
715661
}
@@ -798,34 +744,24 @@ CUresult cccl_device_segmented_sort_impl(
798744
cub::DoubleBuffer<indirect_arg_t> d_values_double_buffer(
799745
*static_cast<indirect_arg_t**>(&val_arg_in), *static_cast<indirect_arg_t**>(&val_arg_out));
800746

801-
// TODO(bgruber): remove all template arguments except the first two (the others can be deduced)
802-
auto exec_status = cub::detail::segmented_sort::dispatch<
803-
Order,
804-
OffsetT, // OffsetT
805-
indirect_arg_t, // KeyT
806-
indirect_arg_t, // ValueT
807-
indirect_iterator_t, // BeginOffsetIteratorT
808-
indirect_iterator_t, // EndOffsetIteratorT
809-
cub::detail::segmented_sort::policy_selector, // PolicySelector
810-
segmented_sort::segmented_sort_kernel_source, // KernelSource
811-
segmented_sort::partition_runtime_tuning_policy // PartitionPolicyHub
812-
>(d_temp_storage,
813-
*temp_storage_bytes,
814-
d_keys_double_buffer,
815-
d_values_double_buffer,
816-
num_items,
817-
num_segments,
818-
indirect_iterator_t{start_offset_in},
819-
indirect_iterator_t{end_offset_in},
820-
is_overwrite_okay,
821-
stream,
822-
/* policy_selector */
823-
*static_cast<cub::detail::segmented_sort::policy_selector*>(build.runtime_policy),
824-
/* partition_max_policy */
825-
*static_cast<segmented_sort::partition_runtime_tuning_policy*>(build.partition_runtime_policy),
826-
/* kernel_source */ segmented_sort::segmented_sort_kernel_source{build},
827-
/* partition_kernel_source */ segmented_sort::partition_kernel_source{build},
828-
/* launcher_factory */ cub::detail::CudaDriverLauncherFactory{cu_device, build.cc});
747+
auto exec_status = cub::detail::segmented_sort::dispatch<Order, OffsetT>(
748+
d_temp_storage,
749+
*temp_storage_bytes,
750+
d_keys_double_buffer,
751+
d_values_double_buffer,
752+
num_items,
753+
num_segments,
754+
indirect_iterator_t{start_offset_in},
755+
indirect_iterator_t{end_offset_in},
756+
is_overwrite_okay,
757+
stream,
758+
/* policy_selector */
759+
*static_cast<cub::detail::segmented_sort::policy_selector*>(build.runtime_policy),
760+
/* partition_policy_selector */
761+
*static_cast<cub::detail::three_way_partition::policy_selector*>(build.partition_runtime_policy),
762+
/* kernel_source */ segmented_sort::segmented_sort_kernel_source{build},
763+
/* partition_kernel_source */ segmented_sort::partition_kernel_source{build},
764+
/* launcher_factory */ cub::detail::CudaDriverLauncherFactory{cu_device, build.cc});
829765

830766
*selector = d_keys_double_buffer.selector;
831767
error = static_cast<CUresult>(exec_status);
@@ -909,8 +845,8 @@ try
909845
// Clean up the runtime policies
910846
std::unique_ptr<cub::detail::segmented_sort::policy_selector> rtp(
911847
static_cast<cub::detail::segmented_sort::policy_selector*>(build_ptr->runtime_policy));
912-
std::unique_ptr<segmented_sort::partition_runtime_tuning_policy> prtp(
913-
static_cast<segmented_sort::partition_runtime_tuning_policy*>(build_ptr->partition_runtime_policy));
848+
std::unique_ptr<cub::detail::three_way_partition::policy_selector> prtp(
849+
static_cast<cub::detail::three_way_partition::policy_selector*>(build_ptr->partition_runtime_policy));
914850
check(cuLibraryUnload(build_ptr->library));
915851

916852
return CUDA_SUCCESS;

0 commit comments

Comments
 (0)