@@ -223,22 +223,18 @@ class W4A8MoeGemmUniversalBase {
223223 static Status can_implement (Arguments const &args)
224224 {
225225 CUTLASS_TRACE_HOST (" W4A8MoeGemmUniversalBase::can_implement()" );
226- // printf("--1\n");
227226 // Initialize static kernel and device properties, if necessary.
228227 Status result = init_device_props ();
229- // printf("--1-2\n");
230228 if (result != Status::kSuccess ) {
231229 return result;
232230 }
233- // printf("--2\n");
234231 dim3 grid = get_grid_shape (args);
235232 // printf("--grid:%d, %d, %d\n", grid.x, grid.y, grid.z);
236233 if (!(grid.y <= std::numeric_limits<uint16_t >::max () &&
237234 grid.z <= std::numeric_limits<uint16_t >::max ()))
238235 {
239236 return Status::kErrorInvalidProblem ;
240237 }
241- // printf("--3\n");
242238 return GemmKernel::can_implement (args);
243239 }
244240
@@ -285,18 +281,50 @@ class W4A8MoeGemmUniversalBase {
285281 }
286282
287283
284+
288285 // / Returns the maximum number of active thread blocks per multiprocessor
289- static int maximum_active_blocks ()
286+ static int maximum_active_blocks (int smem_capacity = - 1 )
290287 {
291288 CUTLASS_TRACE_HOST (" W4A8MoeGemmUniversalBase::maximum_active_blocks()" );
292289
293- // Initialize static device properties, if necessary
294- if (init_device_props () != Status::kSuccess ) {
290+ int smem_size = int (sizeof (typename GemmKernel_::SharedStorage));
291+
292+ CUTLASS_TRACE_HOST (" smem_size: " << smem_size << " bytes" );
293+
294+ cudaError_t result;
295+ if (smem_size > (48 << 10 )) {
296+ result = cudaFuncSetAttribute (Kernel2<GemmKernel_>,
297+ cudaFuncAttributeMaxDynamicSharedMemorySize,
298+ smem_size);
299+
300+ if (result != cudaSuccess) {
301+ // Call cudaGetLastError() to clear the error bit
302+ result = cudaGetLastError ();
303+ CUTLASS_TRACE_HOST (
304+ " cudaFuncSetAttribute() returned error "
305+ << cudaGetErrorString (result));
306+ return -1 ;
307+ }
308+ }
309+
310+ int max_active_blocks = -1 ;
311+ result = cudaOccupancyMaxActiveBlocksPerMultiprocessor (
312+ &max_active_blocks,
313+ Kernel2<GemmKernel_>,
314+ GemmKernel_::kThreadCount ,
315+ smem_size);
316+
317+ if (result != cudaSuccess) {
318+ // Call cudaGetLastError() to clear the error bit
319+ result = cudaGetLastError ();
320+ CUTLASS_TRACE_HOST (
321+ " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
322+ << cudaGetErrorString (result));
295323 return -1 ;
296324 }
297325
298- CUTLASS_TRACE_HOST (" max_active_blocks: " << sm_occupancy_ );
299- return sm_occupancy_ ;
326+ CUTLASS_TRACE_HOST (" max_active_blocks: " << max_active_blocks );
327+ return max_active_blocks ;
300328 }
301329
302330
@@ -341,8 +369,7 @@ class W4A8MoeGemmUniversalBase {
341369
342370 // Configure grid and block dimensions
343371 dim3 block (GemmKernel::kThreadCount , 1 , 1 );
344- // dim3 grid = params_.get_grid_dims();
345- dim3 grid (216 , 1 , 1 );
372+ dim3 grid (params_.threadblock_count , 1 , 1 );
346373
347374 // Launch kernel
348375 CUTLASS_TRACE_HOST (" "
0 commit comments