Skip to content

Commit 6f56614

Browse files
zhyncs and yizhang2077 authored
chore: upgrade cutlass 3.9.2 (#6004)
Co-authored-by: yizhang2077 <[email protected]>
1 parent bdd1799 commit 6f56614

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

sgl-kernel/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ include(FetchContent)
4545
FetchContent_Declare(
4646
repo-cutlass
4747
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
48-
GIT_TAG e94e888df3551224738bfa505787b515eae8352f
48+
GIT_TAG ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e
4949
GIT_SHALLOW OFF
5050
)
5151
FetchContent_Populate(repo-cutlass)

sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -384,29 +384,30 @@ torch::Tensor fp8_blockwise_scaled_mm(
384384

385385
auto sm_version = getSMVersion();
386386

387+
int64_t original_rows = mat_a.size(0);
388+
torch::Tensor mat_a_padded = pad_tensor(mat_a, /*alignment=*/4);
389+
torch::Tensor scales_a_padded = pad_tensor(scales_a, /*alignment=*/4, /*col_major=*/true);
390+
torch::Tensor out_padded = torch::empty({mat_a_padded.size(0), mat_b.size(1)}, out.options());
391+
387392
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
388393
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
389394
if (sm_version == 90) {
390395
torch::Tensor scales_b_contiguous = scales_b.contiguous();
391396
if (out_dtype == torch::kBFloat16) {
392-
sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b_contiguous);
397+
sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(
398+
out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
393399
} else {
394-
sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b_contiguous);
400+
sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(
401+
out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
395402
}
396-
return out;
403+
return out_padded.slice(0, 0, original_rows);
397404
}
398405
#endif
399406
#endif
400407

401408
#if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
402409
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
403410
if (sm_version == 100) {
404-
int64_t original_rows = mat_a.size(0);
405-
406-
torch::Tensor mat_a_padded = pad_tensor(mat_a, /*alignment=*/4);
407-
torch::Tensor scales_a_padded = pad_tensor(scales_a, /*alignment=*/4, /*col_major=*/true);
408-
torch::Tensor out_padded = torch::empty({mat_a_padded.size(0), mat_b.size(1)}, out.options());
409-
410411
if (out_dtype == torch::kBFloat16) {
411412
sm100_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(
412413
out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);

0 commit comments

Comments (0)