@@ -384,29 +384,30 @@ torch::Tensor fp8_blockwise_scaled_mm(
384384
385385 auto sm_version = getSMVersion ();
386386
387+ int64_t original_rows = mat_a.size (0 );
388+ torch::Tensor mat_a_padded = pad_tensor (mat_a, /* alignment=*/ 4 );
389+ torch::Tensor scales_a_padded = pad_tensor (scales_a, /* alignment=*/ 4 , /* col_major=*/ true );
390+ torch::Tensor out_padded = torch::empty ({mat_a_padded.size (0 ), mat_b.size (1 )}, out.options ());
391+
387392#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
388393#if defined CUDA_VERSION && CUDA_VERSION >= 12000
389394 if (sm_version == 90 ) {
390395 torch::Tensor scales_b_contiguous = scales_b.contiguous ();
391396 if (out_dtype == torch::kBFloat16 ) {
392- sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t >(out, mat_a, mat_b, scales_a, scales_b_contiguous);
397+ sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t >(
398+ out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
393399 } else {
394- sm90_fp8_blockwise_dispatch_shape<cutlass::half_t >(out, mat_a, mat_b, scales_a, scales_b_contiguous);
400+ sm90_fp8_blockwise_dispatch_shape<cutlass::half_t >(
401+ out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
395402 }
396- return out ;
403+ return out_padded. slice ( 0 , 0 , original_rows) ;
397404 }
398405#endif
399406#endif
400407
401408#if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
402409#if defined CUDA_VERSION && CUDA_VERSION >= 12080
403410 if (sm_version == 100 ) {
404- int64_t original_rows = mat_a.size (0 );
405-
406- torch::Tensor mat_a_padded = pad_tensor (mat_a, /* alignment=*/ 4 );
407- torch::Tensor scales_a_padded = pad_tensor (scales_a, /* alignment=*/ 4 , /* col_major=*/ true );
408- torch::Tensor out_padded = torch::empty ({mat_a_padded.size (0 ), mat_b.size (1 )}, out.options ());
409-
410411 if (out_dtype == torch::kBFloat16 ) {
411412 sm100_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t >(
412413 out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
0 commit comments