@@ -589,6 +589,7 @@ struct policy_hub
589589
590590 static constexpr int squad_reduce_thread_count = num_reduce_warps * num_threads_per_warp;
591591
592+ // manual tuning based on cub.bench.scan.exclusive.sum.base
592593 // 256 / sizeof(InputValueT) - 1 should minimize bank conflicts (and fits into 48KiB SMEM)
593594 // 2-byte types and double needed special handling
594595 static constexpr int items_per_thread =
@@ -637,7 +638,35 @@ struct policy_hub
637638#endif // __cccl_ptx_isa >= 860
638639 };
639640
640- using MaxPolicy = Policy1000;
641+ struct Policy1200 : ChainedPolicy<1200 , Policy1200, Policy1000>
642+ {
643+ using ScanPolicyT = typename Policy1000::ScanPolicyT;
644+
645+ struct WarpspeedPolicy : Policy1000::WarpspeedPolicy
646+ {
647+ static constexpr int items_per_thread = [] {
648+ auto ipt = Policy1000::WarpspeedPolicy::items_per_thread;
649+
650+ // based on cub.bench.scan.exclusive.custom.base, cap items per thread if we don't know the scan op
651+ if (is_primitive_op<ScanOpT>() == primitive_op::no && ::cuda::std::is_arithmetic_v<InputValueT>)
652+ {
653+ if (sizeof (InputValueT) == 4 || sizeof (InputValueT) == 8 )
654+ {
655+ return 127 ;
656+ }
657+
658+ const int max = sizeof (InputValueT) <= 2 ? 63 : 127 ;
659+ ipt = ::cuda::std::min (ipt, max);
660+ }
661+
662+ return ipt;
663+ }();
664+
665+ static constexpr int tile_size = items_per_thread * Policy1000::WarpspeedPolicy::squad_reduce_thread_count;
666+ };
667+ };
668+
669+ using MaxPolicy = Policy1200;
641670};
642671} // namespace detail::scan
643672
0 commit comments