Skip to content

Commit 7a94110

Browse files
Improve warpspeed scan tuning for sm120
Fixes: #7813
1 parent f51b4a0 commit 7a94110

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

cub/cub/device/dispatch/tuning/tuning_scan.cuh

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,7 @@ struct policy_hub
589589

590590
static constexpr int squad_reduce_thread_count = num_reduce_warps * num_threads_per_warp;
591591

592+
// manual tuning based on cub.bench.scan.exclusive.sum.base
592593
// 256 / sizeof(InputValueT) - 1 should minimize bank conflicts (and fits into 48KiB SMEM)
593594
// 2-byte types and double needed special handling
594595
static constexpr int items_per_thread =
@@ -637,7 +638,35 @@ struct policy_hub
637638
#endif // __cccl_ptx_isa >= 860
638639
};
639640

640-
using MaxPolicy = Policy1000;
641+
struct Policy1200 : ChainedPolicy<1200, Policy1200, Policy1000>
642+
{
643+
using ScanPolicyT = typename Policy1000::ScanPolicyT;
644+
645+
struct WarpspeedPolicy : Policy1000::WarpspeedPolicy
646+
{
647+
static constexpr int items_per_thread = [] {
648+
auto ipt = Policy1000::WarpspeedPolicy::items_per_thread;
649+
650+
// based on cub.bench.scan.exclusive.custom.base, cap items per thread if we don't know the scan op
651+
if (is_primitive_op<ScanOpT>() == primitive_op::no && ::cuda::std::is_arithmetic_v<InputValueT>)
652+
{
653+
if (sizeof(InputValueT) == 4 || sizeof(InputValueT) == 8)
654+
{
655+
return 127;
656+
}
657+
658+
const int max = sizeof(InputValueT) <= 2 ? 63 : 127;
659+
ipt = ::cuda::std::min(ipt, max);
660+
}
661+
662+
return ipt;
663+
}();
664+
665+
static constexpr int tile_size = items_per_thread * Policy1000::WarpspeedPolicy::squad_reduce_thread_count;
666+
};
667+
};
668+
669+
using MaxPolicy = Policy1200;
641670
};
642671
} // namespace detail::scan
643672

0 commit comments

Comments
 (0)