Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions cub/cub/device/dispatch/tuning/tuning_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -573,8 +573,15 @@ struct policy_hub
static constexpr int num_threads_per_warp = 32;

// TODO(bgruber): tune this
static constexpr int num_reduce_warps = 4;
static constexpr int num_scan_stor_warps = 4;
# if _CCCL_COMPILER(NVHPC)
// We need to reduce the number of threads to <= 256 so that each thread can use up to 255 registers. This avoids
// an error in ptxas, see also: https://github.com/NVIDIA/cccl/issues/7700.
static constexpr int num_reduce_warps = 2;
static constexpr int num_scan_stor_warps = 2;
# else // _CCCL_COMPILER(NVHPC)
static constexpr int num_reduce_warps = 4;
static constexpr int num_scan_stor_warps = 4;
# endif // _CCCL_COMPILER(NVHPC)
static constexpr int num_load_warps = 1;
static constexpr int num_sched_warps = 1;
static constexpr int num_look_ahead_warps = 1;
Expand All @@ -587,6 +594,10 @@ struct policy_hub
num_reduce_warps + num_scan_stor_warps + num_load_warps + num_sched_warps + num_look_ahead_warps;
static constexpr int num_total_threads = num_total_warps * num_threads_per_warp;

# if _CCCL_COMPILER(NVHPC)
static_assert(num_total_threads <= 256);
# endif // _CCCL_COMPILER(NVHPC)

static constexpr int squad_reduce_thread_count = num_reduce_warps * num_threads_per_warp;

// 256 / sizeof(InputValueT) - 1 should minimize bank conflicts (and fits into 48KiB SMEM)
Expand Down
Loading