Commit eeb6898

jiawenliu64 authored and facebook-github-bot committed
Optimize cudaGetDeviceProperties runtime overhead
Summary:
X-link: facebookresearch/FBGEMM#1284

Further reduce the runtime overhead of the FP8 kernels by calling `cudaGetDeviceProperties` only once and caching the result, instead of querying the device on every dispatch.

Before this Diff: [Trace](https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree%2Ftraces%2Fdynocli%2F0%2F1748487716%2Flocalhost%2Flibkineto_activities_3431969.json.gz&bucket=gpu_traces)

After this Diff: [Trace](https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree%2Ftraces%2Fdynocli%2F0%2F1748488054%2Flocalhost%2Flibkineto_activities_3821152.json.gz&bucket=gpu_traces)

Differential Revision: D75574880
1 parent 10bf7c1 · commit eeb6898

File tree

3 files changed: +45, -21 lines changed


fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu

Lines changed: 15 additions & 6 deletions
@@ -28,12 +28,21 @@ at::Tensor dispatch_fp8_rowwise_kernel(
   int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
   int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
   int K = XQ.size(-1);
-
-  int arch = 9;
-  cudaDeviceProp prop;
-  cudaGetDeviceProperties(&prop, 0);
-  if (prop.major >= 10) {
-    arch = 10;
+  static int arch = -1;
+  // Avoid expensive cudaGetDeviceProperties call.
+  if (arch < 0) {
+    cudaDeviceProp prop;
+    cudaGetDeviceProperties(&prop, 0);
+    if (prop.major >= 10) {
+      arch = 10;
+      int runtimeVersion;
+      cudaRuntimeGetVersion(&runtimeVersion);
+      TORCH_CHECK(
+          runtimeVersion >= 12080,
+          "FP8 GEMM on sm100a or above requires cuda >= 12.8");
+    } else {
+      arch = 9;
+    }
   }
 
   // Use shape heuristics to dispatch to optimized kernel configuration.
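
All three dispatchers in this commit apply the same memoization pattern: probe the device once, cache the result in a function-local static, and skip the probe on every later call. For reference, here is the pattern distilled into a standalone helper (a minimal sketch; the name get_fp8_dispatch_arch is hypothetical and not part of this commit):

#include <cuda_runtime.h>
#include <c10/util/Exception.h> // TORCH_CHECK

// Cache the SM major version so cudaGetDeviceProperties runs only once
// per process instead of on every kernel dispatch.
int get_fp8_dispatch_arch() {
  static int arch = -1;
  if (arch < 0) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    if (prop.major >= 10) {
      arch = 10;
      int runtimeVersion;
      cudaRuntimeGetVersion(&runtimeVersion);
      TORCH_CHECK(
          runtimeVersion >= 12080,
          "FP8 GEMM on sm100a or above requires cuda >= 12.8");
    } else {
      arch = 9;
    }
  }
  return arch;
}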

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu

Lines changed: 15 additions & 5 deletions
@@ -29,11 +29,21 @@ at::Tensor dispatch_fp8_rowwise_batched_kernel(
     bool use_fast_accum = true,
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
-  int arch = 9;
-  cudaDeviceProp prop;
-  cudaGetDeviceProperties(&prop, 0);
-  if (prop.major >= 10) {
-    arch = 10;
+  static int arch = -1;
+  // Avoid expensive cudaGetDeviceProperties call.
+  if (arch < 0) {
+    cudaDeviceProp prop;
+    cudaGetDeviceProperties(&prop, 0);
+    if (prop.major >= 10) {
+      arch = 10;
+      int runtimeVersion;
+      cudaRuntimeGetVersion(&runtimeVersion);
+      TORCH_CHECK(
+          runtimeVersion >= 12080,
+          "FP8 batched GEMM on sm100a or above requires cuda >= 12.8");
+    } else {
+      arch = 9;
+    }
   }
 
   TORCH_CHECK(

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu

Lines changed: 15 additions & 10 deletions
@@ -30,16 +30,21 @@ at::Tensor dispatch_fp8_grouped_kernel(
     at::Tensor output,
     std::optional<at::Tensor> zero_start_index_M = std::nullopt,
     std::optional<at::Tensor> M_sizes = std::nullopt) {
-  int arch = 9;
-  cudaDeviceProp prop;
-  cudaGetDeviceProperties(&prop, 0);
-  if (prop.major >= 10) {
-    arch = 10;
-    int runtimeVersion;
-    cudaRuntimeGetVersion(&runtimeVersion);
-    TORCH_CHECK(
-        runtimeVersion >= 12080,
-        "FP8 grouped GEMM on blackwell sm100a requires cuda >= 12.8");
+  static int arch = -1;
+  // Avoid expensive cudaGetDeviceProperties call.
+  if (arch < 0) {
+    cudaDeviceProp prop;
+    cudaGetDeviceProperties(&prop, 0);
+    if (prop.major >= 10) {
+      arch = 10;
+      int runtimeVersion;
+      cudaRuntimeGetVersion(&runtimeVersion);
+      TORCH_CHECK(
+          runtimeVersion >= 12080,
+          "FP8 grouped GEMM on sm100a or above requires cuda >= 12.8");
+    } else {
+      arch = 9;
+    }
   }
 
   // Use heuristics to pick the best kernel implementation.
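
A note on the design choice: the `if (arch < 0)` guard is not synchronized, so two threads dispatching concurrently may both run the probe and, strictly speaking, race on the write to `arch`, although both would store the same value. If once-only initialization ever matters, C++11 guarantees thread-safe initialization of a function-local static, which folds the probe into the initializer. An alternative sketch of that variant, not what this commit does:

#include <cuda_runtime.h>

// Initializer for a function-local static: C++11 guarantees it runs
// exactly once, even with concurrent callers.
static int query_arch() {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  return prop.major >= 10 ? 10 : 9;
}

int get_arch_thread_safe() {
  static const int arch = query_arch(); // once-only, thread-safe init
  return arch;
}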

0 commit comments
