File tree Expand file tree Collapse file tree 3 files changed +45
-21
lines changed
fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions Expand file tree Collapse file tree 3 files changed +45
-21
lines changed Original file line number Diff line number Diff line change @@ -28,12 +28,21 @@ at::Tensor dispatch_fp8_rowwise_kernel(
28
28
int M = size_to_dim_ (XQ.dim () - 1 , XQ.sizes ());
29
29
int N = size_to_dim_ (WQ.dim () - 1 , WQ.sizes ());
30
30
int K = XQ.size (-1 );
31
-
32
- int arch = 9 ;
33
- cudaDeviceProp prop;
34
- cudaGetDeviceProperties (&prop, 0 );
35
- if (prop.major >= 10 ) {
36
- arch = 10 ;
31
+ static int arch = -1 ;
32
+ // Cache arch in a static so the expensive cudaGetDeviceProperties call only runs on the first call.
33
+ if (arch < 0 ) {
34
+ cudaDeviceProp prop;
35
+ cudaGetDeviceProperties (&prop, 0 );
36
+ if (prop.major >= 10 ) {
37
+ arch = 10 ;
38
+ int runtimeVersion;
39
+ cudaRuntimeGetVersion (&runtimeVersion);
40
+ TORCH_CHECK (
41
+ runtimeVersion >= 12080 ,
42
+ " FP8 GEMM on sm100a or above requires cuda >= 12.8" );
43
+ } else {
44
+ arch = 9 ;
45
+ }
37
46
}
38
47
39
48
// Use shape heuristics to dispatch to optimized kernel configuration.
Original file line number Diff line number Diff line change @@ -29,11 +29,21 @@ at::Tensor dispatch_fp8_rowwise_batched_kernel(
29
29
bool use_fast_accum = true ,
30
30
std::optional<at::Tensor> bias = std::nullopt,
31
31
std::optional<at::Tensor> output = std::nullopt) {
32
- int arch = 9 ;
33
- cudaDeviceProp prop;
34
- cudaGetDeviceProperties (&prop, 0 );
35
- if (prop.major >= 10 ) {
36
- arch = 10 ;
32
+ static int arch = -1 ;
33
+ // Cache arch in a static so the expensive cudaGetDeviceProperties call only runs on the first call.
34
+ if (arch < 0 ) {
35
+ cudaDeviceProp prop;
36
+ cudaGetDeviceProperties (&prop, 0 );
37
+ if (prop.major >= 10 ) {
38
+ arch = 10 ;
39
+ int runtimeVersion;
40
+ cudaRuntimeGetVersion (&runtimeVersion);
41
+ TORCH_CHECK (
42
+ runtimeVersion >= 12080 ,
43
+ " FP8 batched GEMM on sm100a or above requires cuda >= 12.8" );
44
+ } else {
45
+ arch = 9 ;
46
+ }
37
47
}
38
48
39
49
TORCH_CHECK (
Original file line number Diff line number Diff line change @@ -30,16 +30,21 @@ at::Tensor dispatch_fp8_grouped_kernel(
30
30
at::Tensor output,
31
31
std::optional<at::Tensor> zero_start_index_M = std::nullopt,
32
32
std::optional<at::Tensor> M_sizes = std::nullopt) {
33
- int arch = 9 ;
34
- cudaDeviceProp prop;
35
- cudaGetDeviceProperties (&prop, 0 );
36
- if (prop.major >= 10 ) {
37
- arch = 10 ;
38
- int runtimeVersion;
39
- cudaRuntimeGetVersion (&runtimeVersion);
40
- TORCH_CHECK (
41
- runtimeVersion >= 12080 ,
42
- " FP8 grouped GEMM on blackwell sm100a requires cuda >= 12.8" );
33
+ static int arch = -1 ;
34
+ // Cache arch in a static so the expensive cudaGetDeviceProperties call only runs on the first call.
35
+ if (arch < 0 ) {
36
+ cudaDeviceProp prop;
37
+ cudaGetDeviceProperties (&prop, 0 );
38
+ if (prop.major >= 10 ) {
39
+ arch = 10 ;
40
+ int runtimeVersion;
41
+ cudaRuntimeGetVersion (&runtimeVersion);
42
+ TORCH_CHECK (
43
+ runtimeVersion >= 12080 ,
44
+ " FP8 grouped GEMM on sm100a or above requires cuda >= 12.8" );
45
+ } else {
46
+ arch = 9 ;
47
+ }
43
48
}
44
49
45
50
// Use heuristics to pick the best kernel implementation.
You can’t perform that action at this time.
0 commit comments