@@ -977,8 +977,6 @@ private:
977977namespace detail ::segmented_sort
978978{
979979template <typename PartitionPolicyHub,
980- typename LargeKernelT,
981- typename SmallKernelT,
982980 typename KeyT,
983981 typename ValueT,
984982 typename BeginOffsetIteratorT,
@@ -988,8 +986,6 @@ template <typename PartitionPolicyHub,
988986 typename KernelLauncherFactory,
989987 typename GetFinalOutputOp>
990988CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t sort_with_partitioning (
991- LargeKernelT large_kernel,
992- SmallKernelT small_kernel,
993989 global_segment_offset_t num_segments,
994990 ::cuda::std::int64_t num_items,
995991 BeginOffsetIteratorT d_begin_offsets,
@@ -1082,15 +1078,15 @@ CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t sort_
10821078 launcher_factory (1 , 1 , 0 , stream)
10831079 .doit (
10841080 detail::segmented_sort::DeviceSegmentedSortContinuationKernel<
1085- LargeKernelT ,
1086- SmallKernelT ,
1081+ decltype (kernel_source. SegmentedSortKernelLarge ()) ,
1082+ decltype (kernel_source. SegmentedSortKernelSmall ()) ,
10871083 KeyT,
10881084 ValueT,
10891085 BeginOffsetIteratorT,
10901086 EndOffsetIteratorT,
10911087 KernelLauncherFactory>,
1092- large_kernel ,
1093- small_kernel ,
1088+ kernel_source. SegmentedSortKernelLarge () ,
1089+ kernel_source. SegmentedSortKernelSmall () ,
10941090 current_num_segments,
10951091 d_keys.Current (),
10961092 get_final_output (d_keys, active_policy.large_segment .radix_bits ),
@@ -1138,8 +1134,8 @@ CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t sort_
11381134 }
11391135
11401136 if (const auto error = detail::segmented_sort::device_segmented_sort_continuation (
1141- large_kernel ,
1142- small_kernel ,
1137+ kernel_source. SegmentedSortKernelLarge () ,
1138+ kernel_source. SegmentedSortKernelSmall () ,
11431139 current_num_segments,
11441140 d_keys.Current (),
11451141 get_final_output (d_keys, active_policy.large_segment .radix_bits ),
@@ -1175,13 +1171,12 @@ CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t sort_
11751171
11761172template <typename KeyT,
11771173 typename ValueT,
1178- typename FallbackKernelT,
11791174 typename BeginOffsetIteratorT,
11801175 typename EndOffsetIteratorT,
1176+ typename KernelSource,
11811177 typename KernelLauncherFactory,
11821178 typename GetFinalOutputOp>
11831179CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t sort_without_partitioning (
1184- FallbackKernelT fallback_kernel,
11851180 global_segment_offset_t num_segments,
11861181 BeginOffsetIteratorT d_begin_offsets,
11871182 EndOffsetIteratorT d_end_offsets,
@@ -1190,6 +1185,7 @@ CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t sort_
11901185 DoubleBuffer<ValueT>& d_values,
11911186 device_double_buffer<KeyT>& d_keys_double_buffer,
11921187 device_double_buffer<ValueT>& d_values_double_buffer,
1188+ KernelSource& kernel_source,
11931189 KernelLauncherFactory& launcher_factory,
11941190 const segmented_sort_policy& active_policy,
11951191 GetFinalOutputOp&& get_final_output)
@@ -1207,7 +1203,7 @@ CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t sort_
12071203
12081204 if (const auto error = CubDebug (
12091205 launcher_factory (blocks_in_grid, threads_in_block, 0 , stream)
1210- .doit (fallback_kernel ,
1206+ .doit (kernel_source. SegmentedSortFallbackKernel () ,
12111207 d_keys.Current (),
12121208 get_final_output (d_keys, active_policy.large_segment .radix_bits ),
12131209 d_keys_double_buffer,
@@ -1494,15 +1490,9 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE auto dispatch(
14941490 : (is_num_passes_odd) ? values_allocation.get ()
14951491 : d_values.Alternate ());
14961492
1497- const auto segmented_sort_fallback_kernel = kernel_source.SegmentedSortFallbackKernel ();
1498- const auto segmented_sort_kernel_small = kernel_source.SegmentedSortKernelSmall ();
1499- const auto segmented_sort_kernel_large = kernel_source.SegmentedSortKernelLarge ();
1500-
15011493 if (partition_segments)
15021494 {
15031495 if (const auto error = sort_with_partitioning<PartitionPolicyHub>(
1504- segmented_sort_kernel_large,
1505- segmented_sort_kernel_small,
15061496 num_segments,
15071497 num_items,
15081498 d_begin_offsets,
@@ -1532,7 +1522,6 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE auto dispatch(
15321522 else
15331523 {
15341524 if (const auto error = sort_without_partitioning (
1535- segmented_sort_fallback_kernel,
15361525 num_segments,
15371526 d_begin_offsets,
15381527 d_end_offsets,
@@ -1541,6 +1530,7 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE auto dispatch(
15411530 d_values,
15421531 d_keys_double_buffer,
15431532 d_values_double_buffer,
1533+ kernel_source,
15441534 launcher_factory,
15451535 active_policy,
15461536 get_final_output))
0 commit comments