@@ -23,12 +23,50 @@ pub struct LaunchConfig {
2323 pub shared_mem_bytes : u32 ,
2424}
2525
/// Shared memory configuration for the calculation of the kernel launch configuration.
/// See [LaunchConfig::suggested] for info.
///
/// # Safety
///
/// The `Dynamic` variant contains an unsafe `extern "C"` function to calculate the shared memory
/// size in bytes for a given block size.
/// This function is passed directly to the CUDA driver.
/// The caller must guarantee that this function returns valid smem values for all reasonable
/// block sizes.
#[derive(Clone, Copy, Debug)]
pub enum SharedMemoryConfig {
    /// A fixed shared memory size in bytes, used regardless of the block size.
    Fixed(usize),
    /// A callback that maps a block size to the shared memory size in bytes needed for it.
    Dynamic(unsafe extern "C" fn(block_size: std::ffi::c_int) -> usize),
}

impl SharedMemoryConfig {
    /// For functions with no shared memory.
    pub const fn none() -> Self {
        Self::Fixed(0)
    }

    /// Resolves this configuration to a concrete shared memory size in bytes for `block_size`.
    ///
    /// `Fixed` yields its stored value; `Dynamic` invokes the callback with `block_size`.
    ///
    /// # Panics
    ///
    /// Panics in all build profiles if the resulting size does not fit in a `u32`, or (for
    /// `Dynamic`) if `block_size` does not fit in a C `int`. These checks are unconditional
    /// because a silently truncated size would make a kernel launch with less shared memory
    /// than it was written for.
    fn with_block_size(&self, block_size: u32) -> u32 {
        match *self {
            Self::Fixed(val) => {
                assert!(
                    val <= u32::MAX as usize,
                    "shared memory size exceeds u32::MAX"
                );
                // Lossless: the range was checked just above.
                val as u32
            }
            Self::Dynamic(func) => {
                // Guard the narrowing so the callback never sees a wrapped (negative) size.
                assert!(
                    block_size <= std::ffi::c_int::MAX as u32,
                    "block size exceeds c_int::MAX"
                );
                // SAFETY: whoever constructed the `Dynamic` variant guarantees the callback is
                // valid to call for reasonable block sizes (see the type-level `# Safety`
                // section); the argument was verified above to be representable as `c_int`.
                let smem = unsafe { func(block_size as std::ffi::c_int) };
                assert!(
                    smem <= u32::MAX as usize,
                    "dynamic shared memory size exceeds u32::MAX"
                );
                // Lossless: the range was checked just above.
                smem as u32
            }
        }
    }
}
62+
2663impl LaunchConfig {
2764 /// Creates a [LaunchConfig] with:
2865 /// - block_dim == `1024`
2966 /// - grid_dim == `(n + 1023) / 1024`
3067 /// - shared_mem_bytes == `0`
3168 pub fn for_num_elems ( n : u32 ) -> Self {
69+ debug_assert ! ( n > 0 , "n must be greater than 0" ) ;
3270 const NUM_THREADS : u32 = 1024 ;
3371 let num_blocks = n. div_ceil ( NUM_THREADS ) ;
3472 Self {
@@ -37,6 +75,67 @@ impl LaunchConfig {
3775 shared_mem_bytes : 0 ,
3876 }
3977 }
78+
79+ pub fn for_block_size ( n : u32 , block_size : u32 , smem : SharedMemoryConfig ) -> Self {
80+ debug_assert ! ( n > 0 , "n must be greater than 0" ) ;
81+ debug_assert ! ( block_size > 0 , "block size must be greater than 0" ) ;
82+ let num_blocks = n. div_ceil ( block_size) ;
83+ Self {
84+ grid_dim : ( num_blocks, 1 , 1 ) ,
85+ block_dim : ( block_size, 1 , 1 ) ,
86+ shared_mem_bytes : smem. with_block_size ( block_size) ,
87+ }
88+ }
89+
90+ /// Calculates a launch configuration that _should_ yield a reasonable occupancy on the GPU.
91+ ///
92+ /// # Performance Considerations
93+ ///
94+ /// Note that the values returned by this function are based on calculations done by the
95+ /// driver, provided the loadout of the cuda function, the shared memory specifications, and
96+ /// current hardware.
97+ /// In many cases the configuration provided by this will *not* be the absolute optimum, as
98+ /// GPU performance can be very unpredictable, especially if scheduling of multiple concurrent
99+ /// kernels becomes important.
100+ /// Always benchmark your kernels if you want optimal performance!
101+ /// This is more of a 'good enough for most cases' situation.
102+ pub fn suggested (
103+ n : u32 ,
104+ func : & CudaFunction ,
105+ block_size_limit : Option < u32 > ,
106+ smem : SharedMemoryConfig ,
107+ ) -> Result < Self , DriverError > {
108+ debug_assert ! ( n > 0 , "n must be greater than 0" ) ;
109+ let ( min_grid_size, block_size, shared_mem_bytes) = match smem {
110+ SharedMemoryConfig :: Fixed ( smem_size) => {
111+ let ( g, b) = func. occupancy_max_potential_block_size (
112+ None ,
113+ smem_size,
114+ block_size_limit. unwrap_or ( 0 ) ,
115+ None ,
116+ ) ?;
117+ debug_assert ! ( smem_size <= u32 :: MAX as usize , "shared memory size exceeds u32::MAX" ) ;
118+ ( g, b, smem_size as u32 )
119+ }
120+ SharedMemoryConfig :: Dynamic ( block_size_to_smem_size) => {
121+ let ( g, b) = func. occupancy_max_potential_block_size (
122+ Some ( block_size_to_smem_size) ,
123+ 0 ,
124+ block_size_limit. unwrap_or ( 0 ) ,
125+ None ,
126+ ) ?;
127+ let smem = unsafe { block_size_to_smem_size ( b as std:: ffi:: c_int ) } ;
128+ debug_assert ! ( smem <= u32 :: MAX as usize , "dynamic shared memory size exceeds u32::MAX" ) ;
129+ ( g, b, smem as u32 )
130+ }
131+ } ;
132+ let grid_size = u32:: max ( min_grid_size, n. div_ceil ( block_size) ) ;
133+ Ok ( Self {
134+ block_dim : ( block_size, 1 , 1 ) ,
135+ grid_dim : ( grid_size, 1 , 1 ) ,
136+ shared_mem_bytes,
137+ } )
138+ }
40139}
41140
42141/// The kernel launch builder. Instantiate with [CudaStream::launch_builder()], and then
0 commit comments