Skip to content

Commit 7caa61f

Browse files
jiawenliu64 authored and facebook-github-bot committed
Optimize cudaGetDeviceProperties runtime overhead (#4209)
Summary: Pull Request resolved: #4209 X-link: facebookresearch/FBGEMM#1284 Further optimize FP8 kernels runtime overhead with `cudaGetDeviceProperties` by only triggering it once Before this Diff: [Trace](https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree%2Ftraces%2Fdynocli%2F0%2F1748487716%2Flocalhost%2Flibkineto_activities_3431969.json.gz&bucket=gpu_traces) After this Diff: [Trace](https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree%2Ftraces%2Fdynocli%2F0%2F1748488054%2Flocalhost%2Flibkineto_activities_3821152.json.gz&bucket=gpu_traces) Differential Revision: D75574880
1 parent 10bf7c1 commit 7caa61f

File tree

3 files changed

+48
-21
lines changed

3 files changed

+48
-21
lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <ATen/ATen.h>
1010
#include <ATen/cuda/CUDAContext.h>
11+
#include <c10/cuda/CUDAGuard.h>
1112
// clang-format on
1213

1314
#include "f8f8bf16_rowwise/f8f8bf16_rowwise_manifest.cuh"
@@ -28,12 +29,21 @@ at::Tensor dispatch_fp8_rowwise_kernel(
2829
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
2930
int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
3031
int K = XQ.size(-1);
31-
32-
int arch = 9;
33-
cudaDeviceProp prop;
34-
cudaGetDeviceProperties(&prop, 0);
35-
if (prop.major >= 10) {
36-
arch = 10;
32+
static int arch = -1;
33+
// Avoid expensive cudaGetDeviceProperties call.
34+
if (arch < 0) {
35+
cudaDeviceProp prop;
36+
cudaGetDeviceProperties(&prop, 0);
37+
if (prop.major >= 10) {
38+
arch = 10;
39+
int runtimeVersion;
40+
C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
41+
TORCH_CHECK(
42+
runtimeVersion >= 12080,
43+
"FP8 GEMM on sm100a or above requires cuda >= 12.8");
44+
} else {
45+
arch = 9;
46+
}
3747
}
3848

3949
// Use shape heuristics to dispatch to optimized kernel configuration.

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <c10/cuda/CUDAGuard.h>
910
#include <cute/tensor.hpp>
1011
#include "f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_manifest.cuh"
1112

@@ -29,11 +30,21 @@ at::Tensor dispatch_fp8_rowwise_batched_kernel(
2930
bool use_fast_accum = true,
3031
std::optional<at::Tensor> bias = std::nullopt,
3132
std::optional<at::Tensor> output = std::nullopt) {
32-
int arch = 9;
33-
cudaDeviceProp prop;
34-
cudaGetDeviceProperties(&prop, 0);
35-
if (prop.major >= 10) {
36-
arch = 10;
33+
static int arch = -1;
34+
// Avoid expensive cudaGetDeviceProperties call.
35+
if (arch < 0) {
36+
cudaDeviceProp prop;
37+
cudaGetDeviceProperties(&prop, 0);
38+
if (prop.major >= 10) {
39+
arch = 10;
40+
int runtimeVersion;
41+
C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
42+
TORCH_CHECK(
43+
runtimeVersion >= 12080,
44+
"FP8 batched GEMM on sm100a or above requires cuda >= 12.8");
45+
} else {
46+
arch = 9;
47+
}
3748
}
3849

3950
TORCH_CHECK(

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <ATen/ATen.h>
1010
#include <ATen/cuda/CUDAContext.h>
11+
#include <c10/cuda/CUDAGuard.h>
1112
// clang-format on
1213

1314
#include "f8f8bf16_rowwise_grouped/f8f8bf16_rowwise_grouped_manifest.cuh"
@@ -30,16 +31,21 @@ at::Tensor dispatch_fp8_grouped_kernel(
3031
at::Tensor output,
3132
std::optional<at::Tensor> zero_start_index_M = std::nullopt,
3233
std::optional<at::Tensor> M_sizes = std::nullopt) {
33-
int arch = 9;
34-
cudaDeviceProp prop;
35-
cudaGetDeviceProperties(&prop, 0);
36-
if (prop.major >= 10) {
37-
arch = 10;
38-
int runtimeVersion;
39-
cudaRuntimeGetVersion(&runtimeVersion);
40-
TORCH_CHECK(
41-
runtimeVersion >= 12080,
42-
"FP8 grouped GEMM on blackwell sm100a requires cuda >= 12.8");
34+
static int arch = -1;
35+
// Avoid expensive cudaGetDeviceProperties call.
36+
if (arch < 0) {
37+
cudaDeviceProp prop;
38+
cudaGetDeviceProperties(&prop, 0);
39+
if (prop.major >= 10) {
40+
arch = 10;
41+
int runtimeVersion;
42+
C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
43+
TORCH_CHECK(
44+
runtimeVersion >= 12080,
45+
"FP8 grouped GEMM on sm100a or above requires cuda >= 12.8");
46+
} else {
47+
arch = 9;
48+
}
4349
}
4450

4551
// Use heuristics to pick the best kernel implementation.

0 commit comments

Comments (0)